Frank

Reputation: 169

Why do we pass nn.Module as an argument to the class definition for neural nets?

I want to understand why we pass torch.nn.Module as an argument when we define the class for a neural network, such as a GAN:

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

Upvotes: 5

Views: 3085

Answers (2)

prosti

Reputation: 46331

This line

class Generator(nn.Module):

simply means that the Generator class inherits from the nn.Module class; nn.Module is not an argument.
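To see that the name in parentheses is a parent class rather than a constructor argument, here is a minimal pure-Python sketch. Base and Child are hypothetical stand-ins for nn.Module and Generator, so no PyTorch is needed to run it.

```python
class Base:
    """Hypothetical stand-in for nn.Module."""

    def describe(self):
        # Defined only on Base, yet available on every subclass.
        return f"I am a {type(self).__name__}"


class Child(Base):  # Base is listed as a parent class, not passed as a value
    def __init__(self, size):
        super().__init__()  # run the parent's initializer first
        self.size = size


child = Child(3)                # only `size` is passed when instantiating
print(Child.__bases__)          # the names in the parentheses end up here
print(isinstance(child, Base))  # True: a Child is also a Base
print(child.describe())         # "I am a Child": method inherited from Base
```

When you write Child(3), only size is passed; Base is never handed over as a value. It is recorded once, at class-definition time, in Child.__bases__.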

However, the dunder init method

def __init__(self, input_size, hidden_size, output_size, f):

does take self as its first argument, which may be where the idea of passing an argument comes from.

Here, self is simply the instance of the class. There have been long-running debates over whether the explicit self should stay or go, but Guido explained in his blog why it has to stay.
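A short sketch of what self is: calling a method on an instance is shorthand for calling the function stored on the class with the instance as the first argument. Counter is a made-up example class, not part of PyTorch.

```python
class Counter:
    def __init__(self, start):
        self.value = start  # `self` is the instance being initialised

    def add(self, amount):
        self.value += amount
        return self.value


c = Counter(10)
print(c.add(5))           # 15: Python passes `c` as `self` automatically
print(Counter.add(c, 5))  # 20: the same call with `self` written explicitly
```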

Upvotes: 4

Nicolas Essipova

Reputation: 13

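We are essentially defining the class 'Generator' as a subclass of nn.Module, so it gets all of nn.Module's functionality. In programming this is called inheritance. The call super(Generator, self).__init__() then runs nn.Module's own initializer, so the parent's setup is done before the layers are assigned.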

You can read more here: https://www.w3schools.com/python/python_inheritance.asp
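Why the super(Generator, self).__init__() call matters can be illustrated with a heavily simplified, hypothetical imitation of nn.Module's bookkeeping. MiniModule, Layer and the call_super flag are invented for this sketch; the real nn.Module does far more, but it likewise relies on state created in its __init__ to register submodules, and in PyTorch skipping the call typically shows up as an AttributeError when the first layer is assigned.

```python
class MiniModule:
    """Hypothetical, heavily simplified imitation of nn.Module's bookkeeping."""

    def __init__(self):
        # The parent initializer creates the registry that subclasses rely on.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, MiniModule):
            if "_modules" not in self.__dict__:
                raise AttributeError(
                    "cannot assign module before MiniModule.__init__() call")
            self._modules[name] = value  # register the child module
        object.__setattr__(self, name, value)

    def children(self):
        return list(self._modules.values())


class Layer(MiniModule):
    """Hypothetical stand-in for nn.Linear."""


class Generator(MiniModule):
    def __init__(self, call_super=True):
        if call_super:  # hypothetical flag, only to show what skipping does
            super(Generator, self).__init__()
        self.map1 = Layer()
        self.map2 = Layer()


g = Generator()
print(len(g.children()))  # 2: both layers were registered

try:
    Generator(call_super=False)
except AttributeError as err:
    print(err)  # the registry was never created, so assignment fails
```

The registration through __setattr__ is why the parent initializer has to run first: the child class only adds layers, while the parent owns the machinery that keeps track of them.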

Upvotes: 1
