Let's say I feed 3 grayscale images to a CNN, with a combined shape of (3, 28, 28). This process generates multiple feature maps for each image. How do I identify which feature map corresponds to a particular image?
Here is some code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # expects 3 input channels
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)  # 256 = 16 channels * 4 * 4
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        print("Shape of x = ", x.shape)
        x = self.pool(F.relu(self.conv2(x)))
        print("Shape of x = ", x.shape)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
# The 3 grayscale images stacked as the channels of one sample: (1, 3, 28, 28)
foo = torch.randn(1, 3, 28, 28)
foo_cnn = net(foo)
```
For instance, the first convolution generated 6 feature maps from 3 images. Is there a way to identify which feature map belongs to which image, so that I can perform some operation on it?
To distinguish which image generated which feature maps, you must place the different input images along the batch dimension (number of images = batch size), so that every convolutional layer is applied to each image separately. If the images are instead placed along the channel (depth) dimension, each output feature map is a weighted sum over all of the input images.
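A minimal sketch of this property, assuming only `torch` is installed: convolving a whole batch gives, for each image, the same result as convolving that image on its own, so output index `i` depends only on input image `i`. The layer sizes (1 input channel, 6 output maps, 5x5 kernel) and the random inputs are illustrative choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Three grayscale 28x28 images stacked along the batch dimension: (N, C, H, W)
images = torch.randn(3, 1, 28, 28)
conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)

with torch.no_grad():
    batched = conv(images)  # shape (3, 6, 24, 24)
    # Convolve each image alone (keeping a batch dim of 1) and compare
    for i in range(images.shape[0]):
        single = conv(images[i:i + 1])  # shape (1, 6, 24, 24)
        assert torch.allclose(batched[i], single[0], atol=1e-5)

print(batched.shape)  # torch.Size([3, 6, 24, 24])
```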
Right now, in PyTorch's eyes, you are not feeding 3 images into the model; that would require an input of shape `(3, 1, 28, 28)` for grayscale images, or `(3, 3, 28, 28)` for RGB images. What you are doing instead is, in a sense, concatenating the 3 images along the depth dimension, giving the shape `(1, 3, 28, 28)`. As a result, the 6 output feature maps cannot be attributed to a specific image: each one is a weighted combination of all 3, since they sit in the depth dimension.
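To see why the images cannot be separated when stacked along the channel dimension, the sketch below (assuming a `Conv2d(3, 6, 5)` like the question's `conv1`, with random weights) rebuilds the layer's output as the sum of three single-channel convolutions, one per input image, plus the bias. Every output map therefore mixes all three images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Three grayscale images stacked along the channel dimension: (1, 3, 28, 28)
x = torch.randn(1, 3, 28, 28)
conv = nn.Conv2d(3, 6, 5)

with torch.no_grad():
    full = conv(x)  # shape (1, 6, 24, 24)

    # Rebuild the same output: convolve each input channel (image) c with the
    # slice of the kernel that belongs to that channel, then sum over c.
    rebuilt = torch.zeros_like(full)
    for c in range(3):
        rebuilt += F.conv2d(x[:, c:c + 1], conv.weight[:, c:c + 1])
    rebuilt += conv.bias.view(1, -1, 1, 1)

print(torch.allclose(full, rebuilt, atol=1e-5))  # True
```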
Therefore, reshaping the input to `(3, 1, 28, 28)` and changing `conv1` to `nn.Conv2d(1, 6, 5)` gives an output of shape `(3, 6, 12, 12)` after the first convolution and pooling. Hence the 6 feature maps at batch index 0 of the output correspond to the first image in the input batch, the 6 feature maps at batch index 1 correspond to the second image, and so on.
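A sketch of the corrected setup, assuming the question's `Net` with only `conv1` changed to `nn.Conv2d(1, 6, 5)`. Here `forward` also returns the pooled `conv1` output so the per-image feature maps can be inspected; that extra return value and the names `features` and `second_image_maps` are illustrative additions, not part of the original code. Indexing the first (batch) dimension selects one image's feature maps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)  # 1 input channel: each image is grayscale
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        features = self.pool(F.relu(self.conv1(x)))  # (N, 6, 12, 12)
        x = self.pool(F.relu(self.conv2(features)))  # (N, 16, 4, 4)
        x = torch.flatten(x, 1)  # (N, 256)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Hypothetical extra output: the first-layer feature maps
        return self.fc3(x), features


net = Net()
images = torch.randn(3, 1, 28, 28)  # 3 grayscale images in the batch dimension
logits, features = net(images)

print(features.shape)  # torch.Size([3, 6, 12, 12])
# features[i] holds the 6 feature maps produced from images[i]
second_image_maps = features[1]  # shape (6, 12, 12)
print(second_image_maps.mean())  # an example per-image operation
```

Returning the intermediate tensor is just one option; the key point is that, once the images live in the batch dimension, `features[i]` is computed from `images[i]` alone.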