Reputation: 2714
I want to train VGG on 128x128 images. To save GPU memory and training time, I don't want to rescale them to 224x224. What would be the proper way to do this?
Upvotes: 3
Views: 2443
Reputation: 2714
The best way is to keep the convolutional part as it is and replace the fully connected layers. The convolutional layers work on any input size, so it is even possible to use pretrained weights for that part of the network. Only the fully connected layers depend on the input size (for 128x128 images the final feature map is 512x4x4 instead of 512x7x7), so they must be replaced and randomly initialized. This way one can fine-tune a network with a smaller input size.
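To make the size dependence concrete, here is a small sketch of the arithmetic. It assumes the standard VGG layout (size-preserving 3x3 convolutions, five 2x2 max-pools with stride 2, 512 output channels); the helper name `vgg_flat_features` is made up for this illustration and is not part of torchvision.

```python
def vgg_flat_features(height, width, channels=512, num_pools=5):
    """Flattened size of the VGG conv output for a given input size.

    Assumes the standard VGG layout: the 3x3 convs (padding 1) keep the
    spatial size, and each of the five 2x2 max-pools halves height and
    width, rounding down.
    """
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return channels * height * width


for size in (224, 128):
    n = vgg_flat_features(size, size)
    fc1_params = n * 4096 + 4096  # weights + biases of the first Linear layer
    print(size, n, fc1_params)
# 224 25088 102764544
# 128 8192 33558528
```

So the first fully connected layer needs 8192 inputs instead of 25088, which is why the pretrained classifier weights cannot be reused.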
Here is some PyTorch code:
import torch
import torch.nn as nn
import torchvision

# map each supported name to its torchvision constructor
VGG_TYPES = {'vgg11': torchvision.models.vgg11,
             'vgg11_bn': torchvision.models.vgg11_bn,
             'vgg13': torchvision.models.vgg13,
             'vgg13_bn': torchvision.models.vgg13_bn,
             'vgg16': torchvision.models.vgg16,
             'vgg16_bn': torchvision.models.vgg16_bn,
             'vgg19_bn': torchvision.models.vgg19_bn,
             'vgg19': torchvision.models.vgg19}


class Custom_VGG(nn.Module):

    def __init__(self,
                 ipt_size=(128, 128),
                 pretrained=True,
                 vgg_type='vgg19_bn',
                 num_classes=1000):
        super(Custom_VGG, self).__init__()

        # load convolutional part of vgg
        assert vgg_type in VGG_TYPES, "Unknown vgg_type '{}'".format(vgg_type)
        vgg_loader = VGG_TYPES[vgg_type]
        # note: newer torchvision releases deprecate `pretrained` in favor of `weights`
        vgg = vgg_loader(pretrained=pretrained)
        self.features = vgg.features

        # init fully connected part of vgg: push a dummy input through the
        # conv layers to find the flattened feature size (eval mode keeps the
        # batch-norm running statistics untouched, no_grad skips autograd)
        self.features.eval()
        with torch.no_grad():
            test_out = self.features(torch.zeros(1, 3, ipt_size[0], ipt_size[1]))
        self.features.train()
        self.n_features = test_out.size(1) * test_out.size(2) * test_out.size(3)
        self.classifier = nn.Sequential(nn.Linear(self.n_features, 4096),
                                        nn.ReLU(True),
                                        nn.Dropout(),
                                        nn.Linear(4096, 4096),
                                        nn.ReLU(True),
                                        nn.Dropout(),
                                        nn.Linear(4096, num_classes)
                                        )
        self._init_classifier_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, n_features)
        x = self.classifier(x)
        return x

    def _init_classifier_weights(self):
        # random init for the new fully connected layers
        for m in self.classifier:
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
To create the model, just call:
vgg = Custom_VGG(ipt_size=(128, 128), pretrained=True)
Upvotes: 3