I'm pretty new to CNNs and have been following the code below. I don't understand how and why each argument of `Conv2d()` and `nn.Linear()` was chosen as it is, i.e. the output, filter, channels, weights, padding and stride. I do understand what each one means, though. Can someone succinctly explain the flow through each layer? (Input image size is 32x32x3.)
```python
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # nn.Conv2d(in_channels, out_channels, kernel_size, padding=...)
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # kernel_size=2, stride=2
        # nn.Linear(in_features, out_features)
        self.fc1 = nn.Linear(64 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, 10)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)  # flatten each feature map to a 1D vector
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
```
I think you'll find receptive field arithmetic useful for your understanding.
Your net has 3 convolution layers, each with a kernel size of 3x3, padding of 1 pixel and the default stride of 1, which means that the spatial output of each convolution layer has the same size as its input.
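One way to check the "same size" claim is the output-size formula from the PyTorch `Conv2d` documentation, floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1). The plain-Python sketch below applies it along one dimension; the helper name `conv2d_out` is hypothetical, and dilation is assumed to be 1 (the `Conv2d` default).

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2D convolution along one dimension (hypothetical helper)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# 3x3 kernel, padding=1, stride=1: the output size equals the input size
for size in (32, 16, 8):
    print(size, "->", conv2d_out(size, kernel_size=3, padding=1))

# Without padding, a 3x3 kernel loses one pixel on each border
print(32, "->", conv2d_out(32, kernel_size=3))
```

In general, for stride 1, choosing `padding = (kernel_size - 1) // 2` with an odd kernel size keeps the spatial size unchanged, which is what `padding=1` does for the 3x3 kernels here.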
Each conv layer is followed by 2x2 max pooling with stride 2, which reduces each spatial dimension by a factor of 2.
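`MaxPool2d` uses the same size formula as the convolution, with `stride` defaulting to `kernel_size` and `padding` to 0. A minimal sketch, assuming the default `ceil_mode=False` (so odd sizes are rounded down); the helper name `pool2d_out` is hypothetical:

```python
def pool2d_out(size, kernel_size, stride=None, padding=0, dilation=1):
    """Spatial output size of MaxPool2d along one dimension, ceil_mode=False (hypothetical helper)."""
    if stride is None:
        stride = kernel_size  # MaxPool2d's stride defaults to kernel_size
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

print(pool2d_out(32, 2, 2))  # nn.MaxPool2d(2, 2) halves 32 to 16
print(pool2d_out(7, 2, 2))   # an odd size is rounded down: 7 -> 3
```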
So, in the spatial domain, you have an input of size 32x32; after the first conv and pool its dimensions are 16x16, after the second conv and pool they are 8x8, and after the third conv and pool they are 4x4.
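As for the "feature"/"channel" dimension: the input has 3 channels. The first conv layer has 16 filters (`out_channels=16`), the second has 32 and the third has 64. Each layer's `in_channels` must match the previous layer's `out_channels`, which is why the layers are declared as `Conv2d(3, 16, ...)`, `Conv2d(16, 32, ...)` and `Conv2d(32, 64, ...)`.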
Thus, after three conv layers your feature map has 64 channels (per spatial location).
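The channel chaining can be traced, and each layer's learnable parameters counted, in a few lines of plain Python. This sketch assumes `bias=True` (the `Conv2d` default), so each layer has `out_channels * in_channels * k * k` weights plus `out_channels` biases; the list `layers` and the helper `conv2d_params` are hypothetical names that simply mirror the three `Conv2d` calls in the question.

```python
def conv2d_params(c_in, c_out, k):
    """Learnable parameters of Conv2d(c_in, c_out, k) with bias=True (hypothetical helper)."""
    # one k x k x c_in filter per output channel, plus one bias per output channel
    return c_out * c_in * k * k + c_out

# (in_channels, out_channels, kernel_size) for conv1..conv3 in the question
layers = [(3, 16, 3), (16, 32, 3), (32, 64, 3)]

channels = 3  # RGB input
for i, (c_in, c_out, k) in enumerate(layers, start=1):
    # each layer must consume exactly the channels the previous layer produced
    assert c_in == channels, f"conv{i} expects {c_in} channels, got {channels}"
    print(f"conv{i}: {c_in} -> {c_out} channels, {conv2d_params(c_in, c_out, k)} parameters")
    channels = c_out

print("channels after conv3:", channels)
```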
Overall, an input of size 3x32x32 becomes 64x4x4 after the three conv+pooling layers defined by your network.
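Putting the two rules together, a plain-Python sketch can trace one image's (channels, height, width) shape through the three conv+pool blocks and derive the `in_features` of `fc1`. ReLU is omitted because it does not change the shape; the helper `out_size` is a hypothetical name for the size formula used by both `Conv2d` and `MaxPool2d` (dilation 1 assumed).

```python
def out_size(size, k, stride=1, padding=0):
    """Output size along one dimension for Conv2d/MaxPool2d with dilation=1 (hypothetical helper)."""
    return (size + 2 * padding - (k - 1) - 1) // stride + 1

shape = (3, 32, 32)  # (channels, height, width) of one input image
for out_channels in (16, 32, 64):  # conv1, conv2, conv3
    _, h, w = shape
    h, w = out_size(h, 3, padding=1), out_size(w, 3, padding=1)  # 3x3 conv, padding=1
    h, w = out_size(h, 2, stride=2), out_size(w, 2, stride=2)    # MaxPool2d(2, 2)
    shape = (out_channels, h, w)
    print(shape)

in_features = shape[0] * shape[1] * shape[2]
print("fc1 in_features:", in_features)
```

With a different input resolution, the last printed shape (and therefore the first argument of `fc1`) would change accordingly.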
An `nn.Linear` layer does not assign "spatial" meaning to its inputs and expects a 1D input (per entry in a minibatch), so your `forward` function "eliminates" the spatial dimensions and converts `x` to a 1D vector of 64 * 4 * 4 = 1024 values per image using the `x.view(-1, 64 * 4 * 4)` command. This is also why `fc1` is declared as `nn.Linear(64 * 4 * 4, 500)`.
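The `-1` in `view` tells PyTorch to infer that dimension from the total number of elements. A plain-Python sketch of that inference, assuming a single `-1` and a hypothetical helper name `infer_view_shape` (this is not PyTorch code), shows that a batch of 8 feature maps of size 64x4x4 becomes an 8x1024 matrix, while a mismatched feature-map size cannot be reshaped:

```python
from math import prod

def infer_view_shape(shape, new_shape):
    """Resolve a single -1 in new_shape the way view/reshape do (hypothetical sketch)."""
    total = prod(shape)
    known = prod(d for d in new_shape if d != -1)
    ok = (total % known == 0) if -1 in new_shape else (known == total)
    if not ok:
        raise ValueError(f"shape {new_shape} is invalid for input of size {total}")
    return tuple(total // known if d == -1 else d for d in new_shape)

# A batch of 8 images after the three conv+pool blocks
print(infer_view_shape((8, 64, 4, 4), (-1, 64 * 4 * 4)))

# If the feature maps were 3x3 instead of 4x4, 1024 would not divide the element count
try:
    infer_view_shape((8, 64, 3, 3), (-1, 64 * 4 * 4))
except ValueError as err:
    print("error:", err)
```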