flawr

Reputation: 11638

How does Module.parameters() find the parameters?

I noticed that whenever you create a new net extending torch.nn.Module, you can immediately call net.parameters() to find the parameters relevant for backpropagation.

import torch

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x)

net = MyNet()
print(list(net.parameters()))

But then I wondered: how is this even possible? I just assigned the Linear layer object to a member variable, and it isn't recorded anywhere else (or is it?). Somehow MyNet must keep track of the parameters it uses, but how?

Upvotes: 3

Views: 3566

Answers (1)

Coolness

Reputation: 1992

It's simple, really: iterate over the object's attributes (introspection) and check each one's type:

import torch
from torch import nn


class Example:
    def __init__(self):
        self.special_thing = nn.Parameter(torch.rand(2))
        self.something_else = "ok"

    def get_parameters(self):
        # Walk the instance's attributes and report those that are nn.Parameter objects
        for key, value in self.__dict__.items():
            if isinstance(value, nn.Parameter):
                print(key, "is a parameter!")


e = Example()
e.get_parameters()
# => special_thing is a parameter!
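For comparison, torch.nn.Module does its bookkeeping at assignment time rather than by scanning attributes later: it overrides `__setattr__`, and when the assigned value is a `Parameter` or another `Module` it stores it in the internal `_parameters` or `_modules` dictionary. `parameters()` then walks those dictionaries recursively, which is how `MyNet` finds `fc.weight` and `fc.bias` through `self.fc`. The following is only a minimal pure-Python sketch of that idea; `TinyParameter`, `TinyModule` and `TinyLinear` are hypothetical stand-ins, not PyTorch classes, and the real implementation handles far more (buffers, hooks, deletion, reassignment).

```python
# Hypothetical mini-framework: TinyParameter/TinyModule/TinyLinear are made-up
# names that sketch how nn.Module registers things when attributes are assigned.


class TinyParameter:
    """Stand-in for nn.Parameter: just wraps some data."""

    def __init__(self, data):
        self.data = data


class TinyModule:
    def __init__(self):
        # Bypass our own __setattr__ so the registries themselves are not sorted.
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Every `self.x = ...` lands here; record the value according to its type.
        if isinstance(value, TinyParameter):
            self._parameters[name] = value
        elif isinstance(value, TinyModule):
            self._modules[name] = value
        # Simplification: unlike nn.Module, also keep a normal instance attribute.
        object.__setattr__(self, name, value)

    def named_parameters(self, prefix=""):
        # Own parameters first, then recurse into submodules with dotted names.
        for name, param in self._parameters.items():
            yield prefix + name, param
        for name, module in self._modules.items():
            yield from module.named_parameters(prefix + name + ".")

    def parameters(self):
        for _, param in self.named_parameters():
            yield param


class TinyLinear(TinyModule):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.weight = TinyParameter([[0.0] * n_in for _ in range(n_out)])
        self.bias = TinyParameter([0.0] * n_out)


class MyNet(TinyModule):
    def __init__(self):
        super().__init__()
        self.fc = TinyLinear(5, 5)     # recorded in self._modules by __setattr__
        self.label = "not a parameter"  # plain attribute, not recorded


net = MyNet()
print([name for name, _ in net.named_parameters()])  # ['fc.weight', 'fc.bias']
```

This is also why a subclass must call `super().__init__()` before assigning layers: the registries have to exist before `__setattr__` can file anything into them.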

Upvotes: 4
