Reputation: 15010
I have a pretrained VGG16 model for image classification; its layers are:
```
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
```
After some initial hiccups, it's now working fine. I want to use this model for class activation mapping (CAM) to visualize CNN outputs. I know that to do this, we first have to get the activations of the last convolutional layer in VGG16, then the weight matrix of the last fully connected layer, and finally take the dot product of the two.
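For reference, the usual CAM formulation can be written as below (notation chosen here for illustration): $f_k(x,y)$ is channel $k$ of the last convolutional feature map at position $(x,y)$, $K$ is the number of channels, $Z$ the number of spatial positions, and $w^c_k$, $b_c$ are the weight and bias of the final linear layer for class $c$, applied to the globally average-pooled features.

$$M_c(x,y) = \sum_{k=1}^{K} w^c_k \, f_k(x,y)$$

$$S_c = \sum_{k=1}^{K} w^c_k \, \frac{1}{Z}\sum_{x,y} f_k(x,y) + b_c = \frac{1}{Z}\sum_{x,y} M_c(x,y) + b_c$$

The "dot product" therefore runs over the channel index $k$: the class weight vector must have exactly $K$ entries, one per channel, and the second identity only holds when the last layer acts directly on globally average-pooled features.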
First, I got the class index for the query image using this code:

```python
import torch

model.eval()
with torch.no_grad():
    pred = model(img1.float())
class_idx = torch.argmax(pred).item()  # index of the highest-scoring class
classes[class_idx]
```
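Then I fetched the last convolutional layer's activations for the input image, which have shape torch.Size([1, 512, 14, 14]):

```python
# features[:30] keeps layers 0-29 (up to the last ReLU) and drops the final MaxPool2d
last_conv_feat = torch.nn.Sequential(*list(model.features)[:30])
with torch.no_grad():
    pred_a = last_conv_feat(img1.float())
print(pred_a.shape)  # torch.Size([1, 512, 14, 14]) for a 224x224 input
```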
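As an alternative to re-running a truncated `features` stack, a forward hook can capture the activation during the normal forward pass, so the prediction and the feature map come from the same call. Below is a minimal sketch; `TinyVGGLike`, `save_output` and `activations` are hypothetical names, and the tiny random-weight model only mimics the `features`/`avgpool`/`classifier` layout of torchvision's VGG so the snippet runs without downloading weights. With the real VGG16 you would hook `model.features[29]`, the last ReLU, whose output is what `features[:30]` returns.

```python
import torch
import torch.nn as nn


class TinyVGGLike(nn.Module):
    """Hypothetical stand-in with VGG's attribute layout: features -> avgpool -> classifier."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(nn.Linear(16 * 7 * 7, 10))

    def forward(self, x):
        x = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(x, 1))


activations = {}


def save_output(name):
    # Forward hooks are called as hook(module, inputs, output); keep a detached output
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook


model = TinyVGGLike().eval()
# features[1] is the last ReLU before the final MaxPool in this stand-in
# (the equivalent in torchvision's vgg16 is features[29])
handle = model.features[1].register_forward_hook(save_output("last_conv"))
with torch.no_grad():
    pred = model(torch.randn(1, 3, 28, 28))
handle.remove()  # detach the hook once the activation has been captured

print(pred.shape, activations["last_conv"].shape)
```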
Next, I extracted the weights of the last fully connected layer of the VGG16 classifier, which have shape torch.Size([1000, 4096]):

```python
model.classifier[6].weight.shape  # torch.Size([1000, 4096])
```

From this weight matrix I then took the row for the predicted class index:

```python
w_idx = model.classifier[6].weight[class_idx]  # torch.Size([4096])
```
The problem is that the shapes of the convolutional activations and the fully connected weights don't match: one is [1, 512, 14, 14] and the other is [4096]. How do I take the dot product of these two and get the CAM output?
Upvotes: 2
Views: 1215
Reputation: 1240
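This particular model isn't suitable for the simple approach you describe. The CAM you refer to is extracted from models that have only one linear layer at the end, preceded by a global average pooling layer, so that each weight of that layer corresponds to exactly one channel of the last convolutional feature map. VGG16 does not have this structure: its AdaptiveAvgPool2d keeps a 7×7 grid instead of pooling globally, the resulting 512×7×7 tensor is flattened into a 25088-dimensional vector, and that vector passes through two hidden fully connected layers (with ReLU and dropout). The 4096 weights of `classifier[6]` therefore belong to hidden units, not to the 512 channels. The structure CAM needs looks like this: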
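```python
features = MyConvolutions(x)                 # [N, K, H, W] feature maps
pooled_features = features.mean(dim=(2, 3))  # global average pooling -> [N, K]
predictions = fc(pooled_features)            # one nn.Linear(K, num_classes), weight shape [num_classes, K]
```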
This structure is typical of ResNet architectures and their many derivatives. Hence, unless there's a specific reason to use VGG, my recommendation is to adopt a ResNet architecture.
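A minimal sketch of CAM for a GAP plus single-Linear head, in plain PyTorch. `TinyGapNet` and `class_activation_map` are hypothetical names made up for this example, and the small random-weight model only stands in for a ResNet so the code runs without downloads. The channel-wise dot product is done with `torch.einsum`, the map is upsampled with `F.interpolate` and min-max normalized for display, and the last line checks the property that makes CAM valid: with global average pooling, the class logit equals the spatial mean of the raw CAM plus the bias.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGapNet(nn.Module):
    """Hypothetical ResNet-like stand-in: conv layers -> global average pooling -> one Linear."""

    def __init__(self, num_channels: int = 8, num_classes: int = 5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, num_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.fc = nn.Linear(num_channels, num_classes)

    def forward(self, x):
        feats = self.features(x)         # [N, K, H, W]
        pooled = feats.mean(dim=(2, 3))  # global average pooling -> [N, K]
        return self.fc(pooled), feats    # return feature maps too, for the CAM


def class_activation_map(feats, fc_weight, class_idx, out_size):
    """CAM for the first image in the batch: weighted sum of its K feature maps."""
    w = fc_weight[class_idx]                         # [K], one weight per channel
    cam = torch.einsum("k,khw->hw", w, feats[0])     # [H, W]
    # Upsample to the input resolution so the map can be overlaid on the image
    cam_up = F.interpolate(cam[None, None], size=out_size,
                           mode="bilinear", align_corners=False)[0, 0]
    # Min-max normalize to [0, 1] for visualization
    cam_up = (cam_up - cam_up.min()) / (cam_up.max() - cam_up.min() + 1e-8)
    return cam, cam_up


torch.manual_seed(0)
model = TinyGapNet().eval()
img = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    logits, feats = model(img)
    class_idx = logits.argmax(dim=1).item()
    cam, cam_up = class_activation_map(feats, model.fc.weight, class_idx,
                                       tuple(img.shape[-2:]))

print(feats.shape, cam.shape, cam_up.shape)
# For a GAP + Linear head, the class logit equals mean(raw CAM) + bias
print(torch.allclose(cam.mean() + model.fc.bias[class_idx], logits[0, class_idx], atol=1e-5))
```

With torchvision's ResNet models the same roles are played by the output of `model.layer4` (captured, for example, with a forward hook) and `model.fc.weight`, whose rows have one entry per `layer4` channel (512 for ResNet-18/34, 2048 for ResNet-50).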
------- EDIT -------
If you want to go with VGG, there are two options:
Upvotes: 2