In order to understand how this code works, I have written a small reproducer. How does the self.hidden attribute use the variable x in the forward method?
from torch import nn

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        # Inputs to hidden layer linear transformation
        self.hidden = nn.Linear(784, 256)
        # Output layer, 10 units - one for each digit
        self.output = nn.Linear(256, 10)
        # Define sigmoid activation and softmax output
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Pass the input tensor through each of our operations
        x = self.hidden(x)
        x = self.sigmoid(x)
        x = self.output(x)
        x = self.softmax(x)
        return x
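You misunderstood what self.hidden = nn.Linear(784, 256) does. You wrote that hidden is defined as a function, but this is not true: self.hidden is an object, an instance of the class nn.Linear. The constructor call nn.Linear(784, 256) only creates the layer together with its weight and bias; it never sees x. When you later call self.hidden(x), you are not passing arguments to nn.Linear's constructor; you are passing them to the __call__ method that nn.Linear inherits from nn.Module, which in turn runs nn.Linear's forward method on x.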
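To make the mechanism concrete, here is a minimal pure-Python sketch that mimics it without PyTorch. MiniModule and MiniLinear are hypothetical stand-ins, not PyTorch classes: the real nn.Module.__call__ also runs hooks around forward, and the real nn.Linear works on batched tensors and uses its own weight initialization. The point is only that the constructor stores parameters, and calling the instance later routes x through __call__ into forward.

```python
import random


class MiniModule:
    """Hypothetical stand-in for nn.Module: calling an instance runs its forward()."""

    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs registered hooks; omitted here.
        return self.forward(*args, **kwargs)


class MiniLinear(MiniModule):
    """Hypothetical stand-in for nn.Linear: stores weight and bias when constructed."""

    def __init__(self, in_features, out_features):
        # The constructor only allocates parameters; it never sees the input x.
        # (Random uniform init is illustrative, not PyTorch's actual scheme.)
        self.weight = [[random.uniform(-1, 1) for _ in range(in_features)]
                       for _ in range(out_features)]
        self.bias = [random.uniform(-1, 1) for _ in range(out_features)]

    def forward(self, x):
        # y = W x + b for a single input vector x (no batch dimension).
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.weight, self.bias)]


hidden = MiniLinear(4, 3)   # like self.hidden = nn.Linear(784, 256), just smaller
x = [1.0, 2.0, 3.0, 4.0]
y = hidden(x)               # Python turns this into hidden.__call__(x) -> hidden.forward(x)
print(len(y))               # → 3
```

So in your Network, self.hidden(x) inside forward is the same pattern: the layer object was built once in __init__, and each call feeds the current x through its forward method.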
If you want more details on that, I have expanded on how it works in PyTorch: see this answer.