LP0

Reputation: 107

Understanding how PyTorch nn.Linear works

I am considering the sample code from the documentation:

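import torch
from torch import nn

# A linear layer mapping 20 input features to 30 output features
m = nn.Linear(20, 30)
# A batch of 128 samples with 20 features each
input = torch.randn(128, 20)
output = m(input)
print(output.size())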

The output is:

torch.Size([128, 30])

The constructor of Linear is:

def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:

This is consistent with the way the instance is created, i.e.:

m = nn.Linear(20, 30)

However, when m is used, it is called with a tensor as its argument:

output = m(input)

I do not understand why. Where in the source code is this tensor parameter defined?

Upvotes: 0

Views: 250

Answers (1)

Harshit Kumar

Reputation: 12827

When you do m(input), Python invokes the object's __call__ method (the special method that runs when an instance is called like a function). nn.Linear inherits __call__ from its base class, nn.Module, where it runs any registered hooks and then calls the forward method. For simplicity, you can assume for now that m(input) is equivalent to m.forward(input).
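To make the dispatch concrete, here is a minimal pure-Python sketch of the same pattern. TinyModule and TinyLinear are hypothetical names made up for illustration, not PyTorch classes, and the real nn.Module.__call__ does more (hooks, for example) than this stand-in.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: only the call-to-forward dispatch."""

    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs registered hooks; omitted here.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class TinyLinear(TinyModule):
    """Hypothetical scalar 'layer' computing y = weight * x + bias."""

    def __init__(self, weight, bias):
        # The constructor only configures the layer; no input data is involved yet.
        self.weight = weight
        self.bias = bias

    def forward(self, x):
        # The input data arrives here, when the instance is called.
        return self.weight * x + self.bias


layer = TinyLinear(2.0, 1.0)
result_call = layer(3.0)             # goes through __call__, then forward
result_forward = layer.forward(3.0)  # calls forward directly
print(result_call, result_forward)
```

So the constructor arguments describe the layer, while the tensor is an argument of the call, which ends up as the parameter of forward.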

And what's the input to forward? A tensor.

def forward(self, input: Tensor) -> Tensor:
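You can check this on the layer from the question. The sketch below assumes a CPU run with PyTorch installed. The variable names (x, out_call, out_forward, out_manual) are mine, and the manual formula follows the documented behaviour of nn.Linear, which applies input @ weight.T + bias.

```python
import torch
from torch import nn

torch.manual_seed(0)  # only so that repeated runs use the same random values

m = nn.Linear(20, 30)       # in_features=20, out_features=30
x = torch.randn(128, 20)    # batch of 128 samples, 20 features each

out_call = m(x)             # goes through nn.Module.__call__
out_forward = m.forward(x)  # calls forward directly (this skips hooks)

# The layer's parameters: weight has shape (out_features, in_features)
out_manual = x @ m.weight.T + m.bias

print(m.weight.shape, m.bias.shape)
print(out_call.shape)
print(torch.allclose(out_call, out_forward))
print(torch.allclose(out_call, out_manual, atol=1e-5))
```

In real code, prefer m(x) over m.forward(x), since only the former runs any hooks registered on the module.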

Upvotes: 2
