I am considering the sample code from the documentation:
import torch
from torch import nn
#
m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print(output.size())
The output is:
torch.Size([128, 30])
The constructor of Linear is:
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
This is consistent with the way the instance is created, i.e.:
m = nn.Linear(20, 30)
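To see what the two integer arguments are actually used for, here is a quick check (a sketch assuming a standard PyTorch install). The constructor uses the sizes only to allocate the layer's learnable parameters; no input data is involved at this point:

```python
import torch
from torch import nn

# The constructor only receives sizes; it uses them to allocate parameters.
m = nn.Linear(20, 30)

# nn.Linear stores weight as (out_features, in_features) and bias as (out_features,)
print(m.weight.shape)  # torch.Size([30, 20])
print(m.bias.shape)    # torch.Size([30])
```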
However, when m is used, it receives a tensor as input:
output = m(input)
I do not understand why. The constructor only takes integers, so where in the source code is this tensor argument defined?
When you do m(input), the __call__ method (what is __call__?) is called, which internally calls the forward method and also does some extra bookkeeping, such as running any registered hooks. This logic is written in the base class, nn.Module. For simplicity, assume for now that doing m(input) is equivalent to m.forward(input).
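The dispatch can be illustrated in plain Python. The classes below (TinyModule, TinyScale) are hypothetical and heavily simplified; they are not PyTorch's real implementation, they only show how defining __call__ on a base class lets an instance be "called" with data that the constructor never saw:

```python
class TinyModule:
    """Hypothetical, simplified stand-in for nn.Module (not PyTorch's real code)."""

    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs hooks and other bookkeeping;
        # this sketch keeps only the dispatch to forward().
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class TinyScale(TinyModule):
    """Hypothetical subclass: __init__ stores configuration, forward receives data."""

    def __init__(self, factor):
        self.factor = factor  # configuration, analogous to in_features/out_features

    def forward(self, x):
        return [self.factor * v for v in x]  # data only arrives at call time


s = TinyScale(3)             # runs __init__
print(s([1, 2, 3]))          # runs __call__, which calls forward; prints [3, 6, 9]
print(s.forward([1, 2, 3]))  # same result, calling forward directly
```

The split mirrors nn.Linear: the constructor receives the layer's configuration, while the tensor is an argument of forward, reached through __call__.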
And what is the input to forward? A tensor, as its signature in nn.Linear shows:
def forward(self, input: Tensor) -> Tensor:
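As a sanity check (a sketch assuming a standard PyTorch install), the snippet below compares calling the module, calling forward directly, and the affine map that nn.Linear's forward is understood to compute, F.linear(input, weight, bias), i.e. input @ weight.T + bias. Calling m(x) is still the recommended form, since m.forward(x) skips any registered hooks:

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
m = nn.Linear(20, 30)
x = torch.randn(128, 20)

out_call = m(x)                                 # goes through nn.Module.__call__
out_forward = m.forward(x)                      # calls forward directly (no hooks)
out_functional = F.linear(x, m.weight, m.bias)  # functional form of the same layer
out_manual = x @ m.weight.T + m.bias            # explicit affine map

print(out_call.shape)  # torch.Size([128, 30])
print(torch.allclose(out_call, out_forward))
print(torch.allclose(out_call, out_functional))
print(torch.allclose(out_call, out_manual, atol=1e-5))
```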