Reputation: 1207
I have a class that inherits from torch.nn.Module. When I run this code:
d = net.parameters()
print(len(list(d)))
print(len(list(d)))
print(len(list(d)))
the output is:
10
0
0
So the reference returned by net.parameters() can only be used once. Why is that? After the first use it apparently becomes empty. I ran into this while writing my own Optimizer: I pass net.parameters() as an argument to my new class, and I couldn't use it because of this odd behavior.
Upvotes: 1
Views: 95
Reputation: 13651
This is working as expected. Module.parameters() returns an iterator, more specifically a Python generator. One thing about generators is that they cannot be rewound. So the first list(d) call consumes the entire generator; if you try to do it again, it will be empty.
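For your custom Optimizer, the usual fix is to materialize the generator once in __init__, so the stored parameters can be iterated on every step. A minimal sketch (the MyOptimizer name and the plain SGD-style update are just for illustration, not taken from your code):

import torch

class MyOptimizer:
    def __init__(self, params, lr=0.01):
        # list(...) consumes the generator once and keeps the parameters,
        # so they can be iterated as many times as needed later
        self.params = list(params)
        self.lr = lr

    def step(self):
        with torch.no_grad():
            for p in self.params:
                if p.grad is not None:
                    p -= self.lr * p.grad  # in-place update

opt = MyOptimizer(net.parameters())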
If you're wondering, the .parameters() implementation can be seen here, and it is very simple:
def parameters(self, recurse=True):
    for name, param in self.named_parameters(recurse=recurse):
        yield param
Perhaps it is easier to wrap your mind around it with this toy example:
def g():
    for x in [0, 1, 2, 3, 4]:
        yield x
d = g()
print(list(d)) # prints: [0, 1, 2, 3, 4]
print(list(d)) # prints: []
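And just like net.parameters(), calling the generator function again gives you a brand-new generator, so you can simply create it again whenever you need to iterate a second time:

d = g()           # a fresh generator
print(list(d))    # prints: [0, 1, 2, 3, 4]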
Upvotes: 4