I am following this tutorial for PyTorch, and there is a line of code in the method training_step of MnistModel (a class derived from nn.Module) that makes no sense to me.
The line is
out = self(images)
Can someone please explain what is happening here? Is this correct, and is it a convention I should follow?
Thanks
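Here's the snippet:

import torch.nn as nn
import torch.nn.functional as F

input_size = 28 * 28  # assumed: flattened 28x28 MNIST images (matches reshape(-1, 784))
num_classes = 10      # assumed: the ten MNIST digit classes

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)  # flatten each image into a 784-element vector
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                   # Generate predictions
        loss = F.cross_entropy(out, labels)  # Calculate loss
        print(type(out))
        return loss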
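Inside training_step, self refers to an instance of MnistModel, the same as in any other method defined by the class. The only unusual thing is that self is called like a function. That works because nn.Module defines __call__, so every instance of MnistModel is itself callable, and nn.Module.__call__ in turn runs forward (along with any registered hooks). So out = self(images) generates predictions by passing images through MnistModel.forward. The code is correct, and calling the module instance rather than forward directly is the convention the PyTorch documentation recommends, because calling forward directly silently skips the hooks.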
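To see why calling self inside a method works, here is a minimal pure-Python sketch. The Module class below is a hypothetical, heavily simplified stand-in for nn.Module (the real one does much more), and DoublingModel with its toy "loss" is made up for illustration. The only point is that defining __call__ on a base class makes every subclass instance callable, including from inside its own methods via self.

```python
class Module:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module."""

    def __call__(self, *args, **kwargs):
        # Calling an instance, e.g. model(x) or self(x), lands here.
        # (The real nn.Module also runs hooks around this call.)
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class DoublingModel(Module):
    """Hypothetical model: its 'prediction' is just each input doubled."""

    def forward(self, xb):
        return [2 * x for x in xb]

    def training_step(self, batch):
        inputs, targets = batch
        out = self(inputs)  # same as self.__call__(inputs), which calls self.forward(inputs)
        # Toy "loss": sum of absolute errors between predictions and targets
        return sum(abs(o - t) for o, t in zip(out, targets))


model = DoublingModel()
print(model([1, 2, 3]))                        # [2, 4, 6]
print(model.training_step(([1, 2], [2, 5])))   # 1
```

Here training_step plays the same role as in MnistModel: it never calls forward by name, yet forward is what produces out.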
out = self(images) is equivalent to out = self.__call__(images).
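The following sketch, again hypothetical and not PyTorch's actual code, shows that equivalence and why calling the instance is preferred over calling forward directly. In real PyTorch, nn.Module.__call__ also runs any registered hooks; HookedModule and add_forward_hook below are made-up names that mimic that idea with a single kind of hook.

```python
class HookedModule:
    """Hypothetical simplification of how nn.Module wraps forward with hooks."""

    def __init__(self):
        self._forward_hooks = []

    def add_forward_hook(self, hook):
        # hook(module, input, output) runs after forward; a non-None return replaces output
        self._forward_hooks.append(hook)

    def __call__(self, x):
        out = self.forward(x)
        for hook in self._forward_hooks:
            result = hook(self, x, out)
            if result is not None:
                out = result
        return out


class Square(HookedModule):
    def forward(self, x):
        return x * x


calls = []
m = Square()
m.add_forward_hook(lambda mod, x, out: calls.append((x, out)))  # records every call

a = m(3)           # goes through __call__, so the hook fires
b = m.__call__(3)  # explicit __call__: identical behaviour
c = m.forward(3)   # bypasses __call__, so the hook does NOT fire
print(a, b, c)     # 9 9 9
print(calls)       # [(3, 9), (3, 9)]
```

In PyTorch terms, m(3) corresponds to self(images) and m.forward(3) to self.forward(images): the result is the same, but only the first goes through the hook machinery.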