Oliver Wilken

Reputation: 2714

A proper way to adjust input size of CNN (e.g. VGG)

I want to train VGG on 128x128 images. To save GPU memory and training time, I don't want to rescale them to 224x224. What would be the proper way to do this?

Upvotes: 3

Views: 2443

Answers (1)

Oliver Wilken

Reputation: 2714

The best way is to keep the convolutional part as it is and replace the fully connected layers. The convolutional layers work on any input size, so you can still load pretrained weights for them. Only the first fully connected layer depends on the input size, so the fully connected part has to be rebuilt and randomly initialized. The resulting network can then be fine-tuned on the smaller input size.
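To see why the fully connected part has to be rebuilt, it helps to work out how many features come out of the convolutional part. The sketch below assumes the standard VGG layout (3x3 convolutions with padding 1, five 2x2 max-pools with stride 2, 512 output channels). The helper name `vgg_flat_features` is made up for illustration.

```python
def vgg_flat_features(height, width, out_channels=512, num_pools=5):
    """Flattened size of the VGG conv output for a given input size.

    Assumes the standard VGG layout: 3x3 convolutions with padding 1
    (which keep the spatial size) and five 2x2 max-pools with stride 2
    (each of which halves it, rounding down).
    """
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return out_channels * height * width


print(vgg_flat_features(224, 224))  # 25088 = 512 * 7 * 7
print(vgg_flat_features(128, 128))  # 8192  = 512 * 4 * 4
```

With 128x128 inputs the first linear layer would see 8192 features instead of 25088, so its pretrained weight matrix no longer fits.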

Here is some PyTorch code:

import torch
import torch.nn as nn
import torchvision

# Map each supported name to its torchvision constructor.
VGG_TYPES = {'vgg11': torchvision.models.vgg11,
             'vgg11_bn': torchvision.models.vgg11_bn,
             'vgg13': torchvision.models.vgg13,
             'vgg13_bn': torchvision.models.vgg13_bn,
             'vgg16': torchvision.models.vgg16,
             'vgg16_bn': torchvision.models.vgg16_bn,
             'vgg19': torchvision.models.vgg19,
             'vgg19_bn': torchvision.models.vgg19_bn}


class Custom_VGG(nn.Module):

    def __init__(self,
                 ipt_size=(128, 128),
                 pretrained=True,
                 vgg_type='vgg19_bn',
                 num_classes=1000):
        super().__init__()

        # load convolutional part of vgg
        # (newer torchvision releases prefer `weights=` over `pretrained=`)
        assert vgg_type in VGG_TYPES, "Unknown vgg_type '{}'".format(vgg_type)
        vgg_loader = VGG_TYPES[vgg_type]
        vgg = vgg_loader(pretrained=pretrained)
        self.features = vgg.features

        # infer the flattened feature size with a dummy forward pass;
        # eval() leaves the batch-norm running statistics untouched and
        # no_grad() avoids building an autograd graph
        self.features.eval()
        with torch.no_grad():
            test_ipt = torch.zeros(1, 3, ipt_size[0], ipt_size[1])
            test_out = self.features(test_ipt)
        self.features.train()
        self.n_features = test_out.size(1) * test_out.size(2) * test_out.size(3)

        # init fully connected part of vgg
        self.classifier = nn.Sequential(nn.Linear(self.n_features, 4096),
                                        nn.ReLU(True),
                                        nn.Dropout(),
                                        nn.Linear(4096, 4096),
                                        nn.ReLU(True),
                                        nn.Dropout(),
                                        nn.Linear(4096, num_classes)
                                       )
        self._init_classifier_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, n_features)
        x = self.classifier(x)
        return x

    def _init_classifier_weights(self):
        # same scheme torchvision uses for the VGG linear layers
        for m in self.classifier:
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

To create the network, just call:

vgg = Custom_VGG(ipt_size=(128, 128), pretrained=True)
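For the fine-tuning step, one common option is to give the pretrained convolutional layers a smaller learning rate than the freshly initialized classifier. This is an assumption on my part, not something the setup above requires. The sketch below uses per-parameter-group options of `torch.optim.SGD`. The helper `make_finetune_optimizer`, the `TinyStandIn` model and the learning rates are hypothetical placeholders. The stand-in only mimics the `.features` / `.classifier` attribute names so the example runs without downloading weights.

```python
import torch
import torch.nn as nn


def make_finetune_optimizer(model, base_lr=1e-2, feature_lr_scale=0.1):
    """Build SGD with a smaller learning rate for the pretrained conv layers.

    Assumes `model` exposes `.features` and `.classifier`, like Custom_VGG.
    The learning rates are illustrative, not tuned values.
    """
    param_groups = [
        {"params": model.features.parameters(), "lr": base_lr * feature_lr_scale},
        {"params": model.classifier.parameters(), "lr": base_lr},
    ]
    return torch.optim.SGD(param_groups, lr=base_lr, momentum=0.9)


class TinyStandIn(nn.Module):
    """Hypothetical stand-in with the same attribute names as Custom_VGG."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                                      nn.ReLU(True),
                                      nn.MaxPool2d(2))
        self.classifier = nn.Linear(8 * 4 * 4, 10)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.view(x.size(0), -1))


model = TinyStandIn()
optimizer = make_finetune_optimizer(model)
logits = model(torch.zeros(2, 3, 8, 8))  # batch of two 8x8 dummy images
```

With the real model you would pass the `Custom_VGG` instance instead of `TinyStandIn()`.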

Upvotes: 3
