Reputation: 10531
This code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = Net()
input = torch.randn(1, 1, 32, 32)  # batch of 1, 1 channel, 32x32
out = net(input)
print(out)
```
I am learning Python and trying to understand how this constructor works. My question is about these two lines:
```python
input = torch.randn(1, 1, 32, 32)
out = net(input)
```
In `__init__`, I can't see how `input` is used for initialization.
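`input` is indeed never used by `__init__`; the constructor only creates the layers. The input's shape still has to agree with those layers, though: the `16*5*5` in `fc1` only works out for 32×32 images. Below is a minimal pure-Python sketch of that size arithmetic, assuming stride 1 and no padding for the convolutions (the `nn.Conv2d` defaults) and 2×2 pooling with stride equal to the kernel size (the `F.max_pool2d` default). The helpers `conv_out`, `pool_out` and `flat_features` are hypothetical names used only for illustration.

```python
from math import prod

def conv_out(size: int, kernel: int) -> int:
    """Output size of a stride-1, unpadded convolution."""
    return size - kernel + 1

def pool_out(size: int, kernel: int) -> int:
    """Output size of max pooling whose stride equals its kernel size (floored)."""
    return size // kernel

def flat_features(shape: tuple) -> int:
    """Same result as Net.num_flat_features: product of all dims except the batch dim."""
    return prod(shape[1:])

size = 32                                # input = torch.randn(1, 1, 32, 32)
size = pool_out(conv_out(size, 5), 2)    # conv1: 32 -> 28, pool: 28 -> 14
size = pool_out(conv_out(size, 5), 2)    # conv2: 14 -> 10, pool: 10 -> 5
shape_before_view = (1, 16, size, size)  # batch, conv2 channels, height, width
print(size, flat_features(shape_before_view))  # 5 400
```

With a 28×28 input, for example, the same steps end at 4, so `view` would produce rows of 16*4*4 = 256 features and `fc1`, which expects 400, would raise a shape error.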
Upvotes: 0
Views: 256
Reputation: 12920
You are not passing an argument to a class; you are passing an argument to an object (an instance of that class). There is a difference.
The following example shows how to make an object callable: its class needs to implement the `__call__` method.
```python
class CallableClass:
    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):
        print(args)

class Net(CallableClass):
    def __init__(self):
        super(Net, self).__init__()

net = Net()
net(100)  # prints (100,)
```
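To see the difference between the two kinds of call, compare passing the argument to the class with passing it to an instance. This sketch reuses the `CallableClass`/`Net` names from the example above but redefines them so the block runs on its own; as an assumption for easier checking, `__call__` returns its arguments instead of printing them.

```python
class CallableClass:
    def __call__(self, *args, **kwargs):
        # Runs when an *instance* is called like a function.
        return args

class Net(CallableClass):
    def __init__(self):
        super().__init__()

net = Net()        # calling the class: runs __init__, which takes no extra arguments
print(net(100))    # calling the instance: runs __call__ -> (100,)

try:
    Net(100)       # calling the class with an argument: __init__ does not accept it
except TypeError as err:
    print("TypeError:", err)
```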
Upvotes: 0
Reputation: 248
`net = Net()` calls the `__init__` method without an argument.
`out = net(input)` calls the `__call__` method with `input` as its argument.
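The two lines map onto two different special methods. Here is a minimal sketch with a hypothetical `Traced` class that records which method runs for each kind of call; it also shows that `obj(x)` is looked up on the object's type, i.e. it behaves like `type(obj).__call__(obj, x)`.

```python
class Traced:
    def __init__(self):
        self.log = ["__init__"]        # runs for: obj = Traced()

    def __call__(self, value):
        self.log.append("__call__")    # runs for: obj(value)
        return value * 2

obj = Traced()
out = obj(21)                          # implicit call of __call__
same = type(obj).__call__(obj, 21)     # the same call, spelled out explicitly
print(obj.log, out, same)  # ['__init__', '__call__', '__call__'] 42 42
```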
Since `Net` does not implement `__call__` itself, it is inherited from the base class `nn.Module`. In the source of `nn.Module` you can see `__call__` defined with `input` as a parameter; it passes that input on to `forward`, which is where `Net` actually uses it.
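The link between `net(input)` and `forward` is that `nn.Module.__call__` does some bookkeeping (such as running registered hooks) and then calls `self.forward` with the same arguments. The sketch below imitates only that dispatch in plain Python; `MiniModule` and `MiniNet` are hypothetical stand-ins, not PyTorch's actual implementation.

```python
class MiniModule:
    """Hypothetical stand-in for nn.Module: __call__ passes its arguments to forward()."""

    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs hooks around this; the sketch skips that.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must define forward()")

class MiniNet(MiniModule):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale                  # like Net.__init__: only sets things up

    def forward(self, x):
        return [v * self.scale for v in x]  # like Net.forward: this is where the input is used

net = MiniNet(3)
out = net([1, 2, 3])  # MiniNet has no __call__, so MiniModule.__call__ runs and calls forward
print(out)            # [3, 6, 9]
```

This is also why PyTorch code calls `net(input)` rather than `net.forward(input)`: calling the module runs the registered hooks as well, while calling `forward` directly skips them.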
Upvotes: 1