obadul024

Reputation: 303

What is self referring to in this PyTorch derived nn.Module class method?

I am following this tutorial for PyTorch, and one line of code makes no sense to me. It is in the training_step method of MnistModel, a class derived from nn.Module.

The line is out = self(images)

Can someone please explain what is happening here? Is this correct, and is it a convention I should follow?

Thanks

Here's the snippet:


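import torch.nn as nn
import torch.nn.functional as F

# Assumed values, not shown in the original snippet:
# MNIST images are 28x28 pixels and there are 10 digit classes.
input_size = 28 * 28
num_classes = 10

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)            # Flatten each image to a 784-vector
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        print(type(out))
        return loss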

Upvotes: 0

Views: 468

Answers (1)

chepner

Reputation: 532313

It refers to an instance of MnistModel, the same as in any other method defined by the class. The only thing odd is that self is called, but that's explained by the fact that nn.Module defines __call__, so all instances of MnistModel are themselves callable.

out = self(images) is equivalent to out = self.__call__(images), and nn.Module's __call__ in turn runs self.forward(images), along with any registered hooks.
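
To see the mechanism without PyTorch, here is a minimal sketch in plain Python. TinyModule and Doubler are hypothetical names made up for illustration. TinyModule only imitates the one behaviour relevant here (a __call__ that forwards to forward) and is not how nn.Module is actually implemented.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: only the __call__ -> forward dispatch."""

    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs registered hooks; omitted here.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must define forward")


class Doubler(TinyModule):
    def forward(self, xs):
        # Stand-in for a real model's computation.
        return [2 * x for x in xs]

    def training_step(self, batch):
        # self is the Doubler instance, so self(batch) goes through
        # TinyModule.__call__, which calls self.forward(batch).
        return self(batch)


model = Doubler()
via_call = model([1, 2, 3])             # calling the instance
via_dunder = model.__call__([1, 2, 3])  # the explicit spelling of the same call
via_step = model.training_step([1, 2, 3])
```

All three calls end up in Doubler.forward, which is why self(images) inside training_step produces the model's predictions. Calling the instance rather than forward directly is the usual convention in PyTorch code, since it lets any registered hooks run.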

Upvotes: 3
