Dorki

Module.parameters() returns nothing after the first use

I have a class that inherits from torch.nn.Module.

Now when I run this code:

d = net.parameters()
print(len(list(d)))
print(len(list(d)))
print(len(list(d)))

the output is:

10
0
0

So it seems I can only use net.parameters() once. Why is that?

After the first use the parameters apparently disappear. I ran into this while writing my own optimizer: I pass net.parameters() as an argument to my new class, and I couldn't use it there because of this odd behavior.

Answers (1)

Berriel

This is working as expected. Module.parameters() returns an iterator, more specifically a Python generator, and a generator cannot be rewound. The first list(d) call "consumes" the entire generator, so every later attempt to iterate over the same d finds it already exhausted and gets nothing back.
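The difference from a list is worth spelling out: calling iter() on a list creates a new, independent iterator every time, while calling iter() on a generator just returns the generator itself, so there is only one position to consume from. A small sketch (the names gen and count are only for illustration):

```python
def gen():
    # Produces three values, then is exhausted for good.
    yield from [10, 20, 30]

def count(iterable):
    # Count items by iterating once, like len(list(...)) in the question.
    return sum(1 for _ in iterable)

g = gen()
lst = [10, 20, 30]

# iter() on a generator returns the generator itself: one shared position.
print(iter(g) is g)            # prints: True
# iter() on a list builds a fresh, independent iterator every time.
print(iter(lst) is iter(lst))  # prints: False

print(count(g), count(g))      # prints: 3 0
print(count(lst), count(lst))  # prints: 3 3
```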

If you're wondering, the .parameters() implementation in the PyTorch source is very simple:

def parameters(self, recurse=True):
    for name, param in self.named_parameters(recurse=recurse):
        yield param
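Because it is a generator function, every call to .parameters() starts a brand-new generator. Calling net.parameters() each time you need the parameters therefore works; only reusing one stored generator fails. The sketch below mimics this with a hypothetical TinyModule class that keeps its "parameters" in a dict. It is not PyTorch code, just the same generator pattern:

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module, holding named 'parameters'."""

    def __init__(self):
        self._params = {"weight": [0.5, -0.2], "bias": [0.1]}

    def named_parameters(self):
        # Generator over (name, value) pairs.
        for name, value in self._params.items():
            yield name, value

    def parameters(self):
        # Same shape as the method above: drop the names, yield the values.
        for _, value in self.named_parameters():
            yield value

net = TinyModule()

# A fresh generator on each call: both counts are 2.
print(len(list(net.parameters())), len(list(net.parameters())))  # prints: 2 2

# One stored generator: the second count is 0, as in the question.
d = net.parameters()
print(len(list(d)), len(list(d)))  # prints: 2 0
```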

Perhaps it is easier to wrap your mind around it with this toy example:

def g():
    for x in [0, 1, 2, 3, 4]:
        yield x

d = g()
print(list(d))  # prints: [0, 1, 2, 3, 4]
print(list(d))  # prints: []
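For your own optimizer, the usual fix is to consume the iterable exactly once, in the constructor, and keep the result in a list that can be iterated as often as needed. The sketch below assumes a hypothetical Param class (a value plus a gradient) and a hypothetical SimpleSGD class in place of real tensors and torch.optim, so it runs without PyTorch; with a real model you would pass net.parameters() in the same way:

```python
from dataclasses import dataclass

@dataclass
class Param:
    """Hypothetical stand-in for a tensor parameter."""
    value: float
    grad: float = 0.0

class SimpleSGD:
    """Hypothetical optimizer that accepts any iterable of parameters."""

    def __init__(self, params, lr=0.1):
        # Consume the iterable once and keep the results, so later
        # steps do not depend on an already exhausted generator.
        self.params = list(params)
        self.lr = lr

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

def parameters(ps):
    # A generator, like Module.parameters().
    for p in ps:
        yield p

ps = [Param(1.0, grad=1.0), Param(2.0, grad=-1.0)]
opt = SimpleSGD(parameters(ps), lr=0.5)
opt.step()
opt.step()  # the second step still sees both parameters
print([p.value for p in ps])  # prints: [0.0, 3.0]
```

This is also why the built-in torch.optim optimizers can be handed net.parameters() directly: their constructor turns the argument into a list.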