Reputation: 11638
I noticed that whenever you create a new net extending torch.nn.Module, you can immediately call net.parameters() to find the parameters relevant for backpropagation.
import torch

class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x)

net = MyNet()
print(list(net.parameters()))
But then I wondered: how is this even possible? I just assigned this Linear layer object to a member variable, and it is not recorded anywhere else (or is it?). Somehow MyNet must be able to keep track of the parameters used, but how?
Upvotes: 3
Views: 3566
Reputation: 1992
It's simple really: just go through the object's attributes with a bit of introspection and check their type.
import torch
import torch.nn as nn

class Example():
    def __init__(self):
        self.special_thing = nn.Parameter(torch.rand(2))
        self.something_else = "ok"

    def get_parameters(self):
        # Inspect every instance attribute and pick out the Parameters.
        for key, value in self.__dict__.items():
            if isinstance(value, nn.Parameter):
                print(key, "is a parameter!")

e = Example()
e.get_parameters()
# => special_thing is a parameter!
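For what it's worth, this is close to what torch.nn.Module itself does, except that it intercepts the assignment rather than scanning __dict__ afterwards: it overrides __setattr__ so that assigning an nn.Parameter (or a sub-module) registers it in an internal dict, and parameters() then walks those dicts recursively. A simplified sketch of that mechanism (MiniModule, Inner, and Outer are made-up names for illustration, not PyTorch API):

```python
import torch
import torch.nn as nn

class MiniModule:
    """Toy re-implementation of the registration trick in nn.Module."""

    def __init__(self):
        # Use object.__setattr__ here to avoid triggering our own hook
        # before the bookkeeping dicts exist.
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Every attribute assignment passes through here, so we can
        # record Parameters and sub-modules as a side effect.
        if isinstance(value, nn.Parameter):
            self._parameters[name] = value
        elif isinstance(value, MiniModule):
            self._modules[name] = value
        object.__setattr__(self, name, value)

    def parameters(self):
        # Yield our own parameters, then recurse into sub-modules.
        yield from self._parameters.values()
        for module in self._modules.values():
            yield from module.parameters()

class Inner(MiniModule):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.rand(3))

class Outer(MiniModule):
    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(torch.rand(2))
        self.inner = Inner()  # registered via __setattr__ as a sub-module

net = Outer()
print(len(list(net.parameters())))  # => 2
```

So the Linear layer in the question is recorded somewhere: the moment you do self.fc = torch.nn.Linear(5, 5), the overridden __setattr__ files it away, and net.parameters() later collects the weights and biases from it recursively.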
Upvotes: 4