Reputation: 1199
I am trying to learn PyTorch, but I am really confused about the shape of the fully connected layer that comes after convolution and max pooling.
Case 1: How do we calculate the number 5408 in nn.Linear?
I think 5408 = 32 * m * m, where the 32 comes from nn.Conv2d(3, 32, kernel_size=7, stride=2), but then m would have to equal 13 (32 * 13 * 13 = 5408). Where does the 13 come from?
simple_model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=7, stride=2),
    nn.ReLU(inplace=True),
    Flatten(),            # see above for explanation
    nn.Linear(5408, 10),  # affine layer
)
Case 2: How do we get the number 4*4 in fc = nn.Linear(64*4*4, 10)?
Same problem as in case 1: I don't know where the 4 comes from...
# (conv -> batchnorm -> relu -> maxpool) * 3 -> fc
layer1 = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=5, padding=2),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.MaxPool2d(2)
)
layer2 = nn.Sequential(
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.MaxPool2d(2)
)
layer3 = nn.Sequential(
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(2)
)
fc = nn.Linear(64*4*4, 10)
Upvotes: 0
Views: 1406
Reputation: 9426
Here is a good primer (specifically the summary) on calculating these sorts of things: http://cs231n.github.io/convolutional-networks/
The relevant formula there gives the output width (and likewise the height) of a conv layer as

(W - F + 2P)/S + 1

where W is the input width, F is the filter (kernel) width, P is the padding, and S is the stride. You haven't mentioned your input width/height, but for the numbers to work out they must be 32x32 (28x28 MNIST images wouldn't give 5408). In that case we have:

W = 32, F = 7, P = 0 (the nn.Conv2d default), S = 2

Plugging those numbers into the above equation gives (32 - 7 + 0)/2 + 1 = 13.5, which is awkward because it's not an integer. PyTorch rounds it down to 13, and 32 * 13 * 13 = 5408, which is exactly the in_features of the nn.Linear layer. (It's actually proven kind of hard to find documentation of this rounding behaviour anywhere other than this forum post.)
Edit: The actual implementation for cuDNN is here: https://github.com/pytorch/pytorch/blob/fdab1cf0d485820907d7541266d69b70e1d3d16b/aten/src/ATen/native/cudnn/Conv.cpp#L157-L158
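You can also verify the shape empirically by pushing a dummy batch through the conv layer, which is a quick sanity check whenever the arithmetic gets awkward (a minimal sketch, assuming the 32x32 input size from above):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 32, kernel_size=7, stride=2)  # padding defaults to 0

x = torch.randn(1, 3, 32, 32)  # dummy batch: one 3-channel 32x32 image (assumed size)
print(conv(x).shape)           # torch.Size([1, 32, 13, 13])
print(32 * 13 * 13)            # 5408, the in_features of nn.Linear(5408, 10)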
For your second case the inputs are likewise 32x32. The convolutions there don't shrink the height and width (you can plug the numbers in yourself and check: e.g. (32 - 5 + 2*2)/1 + 1 = 32 for the first conv). However, each MaxPool2d(2) layer halves the height and width, so you go from:

32x32 -> 16x16 -> 8x8 -> 4x4

leaving the fully connected layer with 64 * 4 * 4 = 1024 input features.
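The same dummy-batch check confirms this (again a sketch assuming 32x32 inputs, reusing the layer definitions from the question):

import torch
import torch.nn as nn

layer1 = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=5, padding=2),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.MaxPool2d(2)
)
layer2 = nn.Sequential(
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.MaxPool2d(2)
)
layer3 = nn.Sequential(
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(2)
)

x = torch.randn(1, 3, 32, 32)  # assumed 32x32 input
h1 = layer1(x)                 # torch.Size([1, 16, 16, 16])
h2 = layer2(h1)                # torch.Size([1, 32, 8, 8])
h3 = layer3(h2)                # torch.Size([1, 64, 4, 4])
print(h3.flatten(1).shape)     # torch.Size([1, 1024]) = 64*4*4, matching fc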
Upvotes: 2