Reputation: 169
I want to understand why we pass torch.nn.Module as an argument when we define the class for a neural network such as a GAN:
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        # Run nn.Module's constructor before assigning any layers
        super(Generator, self).__init__()
        # Three fully connected layers
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        # f is stored for later use (presumably an activation function)
        self.f = f
Upvotes: 5
Views: 3085
Reputation: 46331
This line
class Generator(nn.Module):
simply means that the Generator class inherits from the nn.Module class. nn.Module is a base class here, not an argument.
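To see that the parentheses in a class statement list base classes rather than pass arguments, here is a minimal plain-Python sketch. The classes Animal and Dog are hypothetical and unrelated to PyTorch: Animal is never called when Dog is defined, yet Dog picks up its methods.

```python
class Animal:
    def __init__(self, name):
        self.name = name

    def describe(self):
        return f"{self.name} is a {type(self).__name__}"


# Animal in the parentheses is a base class, not an argument:
# nothing is called here; Dog simply inherits Animal's methods.
class Dog(Animal):
    def speak(self):
        return "Woof"


d = Dog("Rex")                     # uses the __init__ inherited from Animal
print(d.describe())                # Rex is a Dog
print(d.speak())                   # Woof
print(Dog.__bases__ == (Animal,))  # True
print(issubclass(Dog, Animal))     # True
```

Generator(nn.Module) works the same way: Generator gets all of nn.Module's methods without nn.Module being "passed" anywhere.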
However, the dunder init method
def __init__(self, input_size, hidden_size, output_size, f):
does take arguments, starting with self, which may be why you think of this as an argument.
self is the Python class instance itself; Python supplies it automatically when you call a method on an instance. There were long debates over whether the explicit self should stay or go, but Guido explained in his blog why it has to stay.
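A small sketch of how self gets filled in, assuming a hypothetical Counter class: calling a method on an instance is shorthand for calling the function on the class with the instance as the first argument.

```python
class Counter:
    def __init__(self, start):
        # self is the freshly created instance, supplied by Python
        self.value = start

    def add(self, amount):
        self.value += amount
        return self.value


c = Counter(10)
a = c.add(5)           # Python passes c as self implicitly
b = Counter.add(c, 5)  # the same call with self passed explicitly
print(a, b)            # 15 20
print(c.add.__self__ is c)  # True: the bound method remembers its instance
```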
Upvotes: 4
Reputation: 13
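We are essentially defining the class Generator so that it has all the functionality of nn.Module. In programming this is called inheritance. The call super(Generator, self).__init__() then runs nn.Module's own constructor, so that the inherited machinery (for example, the automatic registration of the nn.Linear layers and their parameters) is set up before the layers are assigned.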
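The sketch below, which assumes PyTorch is installed, illustrates what inheriting from nn.Module gives the Generator from the question. The forward method and the BrokenGenerator class are hypothetical additions for illustration (the question's code has no forward). Because super(Generator, self).__init__() runs first, the nn.Linear layers are registered automatically and show up in parameters(); leaving that call out makes the first layer assignment fail.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()  # sets up nn.Module's internal registries
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        # Hypothetical forward pass: apply f after the first two layers
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.map3(x)


class BrokenGenerator(nn.Module):
    def __init__(self):
        # super().__init__() deliberately omitted
        self.map1 = nn.Linear(4, 8)


g = Generator(input_size=4, hidden_size=8, output_size=2, f=torch.relu)

# Submodules assigned as attributes were registered automatically
print([name for name, _ in g.named_children()])  # ['map1', 'map2', 'map3']

# parameters() is inherited: weights and biases of all three layers
n_params = sum(p.numel() for p in g.parameters())
print(n_params)  # (4*8 + 8) + (8*8 + 8) + (8*2 + 2) = 130

# Calling the instance goes through nn.Module.__call__, which invokes forward
out = g(torch.randn(3, 4))
print(out.shape)  # torch.Size([3, 2])

# Without super().__init__(), assigning a layer raises AttributeError
try:
    BrokenGenerator()
    broken_error = ""
except AttributeError as e:
    broken_error = str(e)
print(broken_error)
```

Calling g(x) rather than g.forward(x) is the usual style, since nn.Module.__call__ (another inherited method) also runs any registered hooks before and after forward.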
You can read more here: https://www.w3schools.com/python/python_inheritance.asp
Upvotes: 1