S dB

Reputation: 11

How to change the first convolution of a pretrained ResNet in TensorFlow?

Hi, I need to change the first convolution of a model from rgb/resnet_v1_50/conv1/weights:0 (float32_ref 7x7x3x64) to rgb/resnet_v1_50/conv1/weights:0 (float32_ref 7x7x4x64). In other words, I want to increase the number of input channels of the filters from 3 to 4 so the network accepts 4-channel images, while keeping the pretrained weights everywhere else (only the additional channel should be initialized randomly).

Do you have an idea how to do that in TensorFlow 1.x? (I'm more of a PyTorch guy...)

In PyTorch I do:

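import torch
import torch.nn as nn

# `model` and `dataset_train` come from the surrounding project
net = model.resnet50(num_classes=dataset_train.num_classes(), pretrained=True)

# New first convolution with 4 input channels; its default init covers the extra channel
new_conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
conv1 = net.conv1

with torch.no_grad():
    # Copy the pretrained RGB filters into the first 3 input channels
    new_conv1.weight[:, :3, :, :] = conv1.weight
    # conv1 is created with bias=False, so there is no bias to copy

net.conv1 = new_conv1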
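When porting this to TF 1.x, note that the two frameworks lay out conv kernels differently: PyTorch stores them as (out_channels, in_channels, kh, kw), while TensorFlow/slim uses (kh, kw, in_channels, out_channels), which is why the variable shows up as 7x7x3x64. A minimal NumPy sketch of the conversion (the helper name `torch_to_tf_kernel` is hypothetical), showing that the PyTorch slice `[:, :3, :, :]` corresponds to `[:, :, :3, :]` in TF layout:

```python
import numpy as np


def torch_to_tf_kernel(w):
    """Convert a conv kernel from PyTorch layout (out, in, kh, kw)
    to TensorFlow layout (kh, kw, in, out)."""
    return np.transpose(w, (2, 3, 1, 0))


# Stand-in for a 4-channel PyTorch conv1 weight (random values, not pretrained)
w_torch = np.random.rand(64, 4, 7, 7).astype(np.float32)
w_tf = torch_to_tf_kernel(w_torch)
print(w_tf.shape)  # (7, 7, 4, 64)

# The first 3 input channels line up: [:, :3, :, :] in PyTorch == [:, :, :3, :] in TF
print(np.array_equal(w_tf[:, :, :3, :], torch_to_tf_kernel(w_torch[:, :3, :, :])))  # True
```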

Here is how the model is created in TensorFlow:

def single_stream(self, images, modality, is_training, reuse=False):

    with tf.variable_scope(modality, reuse=reuse):
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            _, end_points = resnet_v1.resnet_v1_50(
                images, self.no_classes, is_training=is_training, reuse=reuse)

    # last bottleneck before logits
    net = end_points[modality + '/resnet_v1_50/block4']
    if 'autoencoder' in self.mode:
        return net

    with tf.variable_scope(modality + '/resnet_v1_50', reuse=reuse):
        bottleneck = slim.conv2d(net, self.hidden_repr_size, [
                                 7, 7], padding='VALID', activation_fn=tf.nn.relu, scope='f_repr')
        net = slim.conv2d(bottleneck, self.no_classes, [
                          1, 1], activation_fn=None, scope='_logits_')

    if ('train_hallucination' in self.mode or 'test_disc' in self.mode or 'train_eccv' in self.mode):
        return net, bottleneck

    return net

In build_model, defining the input as self.images = tf.placeholder(tf.float32, [None, 224, 224, 4], modality + '_images') does change the 3 to a 4: rgb/resnet_v1_50/conv1/weights:0 (float32_ref 7x7x4x64) [12544, bytes: 50176]. The problem is now the checkpoint: its conv1 weights are still 7x7x3x64, so they no longer match this variable.
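One possible TF 1.x approach, sketched under assumptions and not tested against this model: build the graph with the 4-channel placeholder, restore every variable except conv1 with a `tf.train.Saver` whose `var_list` excludes it, read the old 7x7x3x64 kernel with `tf.train.NewCheckpointReader(path).get_tensor(name)`, pad it to 7x7x4x64, and assign the result to the new conv1 variable (e.g. with `tf.assign` run in the session). If the checkpoint was saved without the `rgb/` scope, `var_list` can be a dict mapping checkpoint names to graph variables. The padding step is plain NumPy; the helper `expand_conv1_kernel` and its `stddev`/`seed` parameters are hypothetical, and the normal initialization of the extra channel is an arbitrary choice:

```python
import numpy as np


def expand_conv1_kernel(kernel, new_in_channels=4, stddev=0.01, seed=0):
    """Grow a TF-layout conv kernel (kh, kw, in, out) along the input-channel axis.

    The pretrained slices are kept unchanged; the added input channels are drawn
    from N(0, stddev^2). Hypothetical helper, not part of TensorFlow.
    """
    kh, kw, in_c, out_c = kernel.shape
    if new_in_channels < in_c:
        raise ValueError("new_in_channels must be >= the current number of input channels")
    rng = np.random.default_rng(seed)
    extra = rng.normal(0.0, stddev, size=(kh, kw, new_in_channels - in_c, out_c))
    return np.concatenate([kernel, extra.astype(kernel.dtype)], axis=2)


# Stand-in for the 7x7x3x64 kernel read from the checkpoint
old_kernel = np.random.rand(7, 7, 3, 64).astype(np.float32)
new_kernel = expand_conv1_kernel(old_kernel)
print(new_kernel.shape)  # (7, 7, 4, 64)
print(np.array_equal(new_kernel[:, :, :3, :], old_kernel))  # True
```

Because conv1 is left out of the Saver, the restore no longer hits the shape mismatch; the padded kernel is written into the variable afterwards.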

Thanks a lot for your help!

Upvotes: 1

Views: 1304

Answers (1)

GianAnge

Reputation: 633

As you do in PyTorch, you can do the same in Keras, which is now part of TF2 as tf.keras.

Here is one possible way to do it:

import numpy

net_conv1 = model.layers[2]  # first 2D convolutional layer (check model.layers or model.summary())
# the new weights must have the same shapes as the layer's current weights
# (weights[0] is the kernel, weights[1] the bias; this assumes the layer has a bias)
print(net_conv1.weights[0].shape)
print(net_conv1.weights[1].shape)
# New weights
osh_0 = net_conv1.weights[0].shape.as_list()
osh_1 = net_conv1.weights[1].shape.as_list()
print(osh_0, osh_1)
new_conv1_w_0 = numpy.random.rand(*osh_0)
new_conv1_w_1 = numpy.random.rand(*osh_1)
# update the weights; set_weights modifies the layer in place, so the model sees the change
net_conv1.set_weights([new_conv1_w_0, new_conv1_w_1])
# check the result
net_conv1.get_weights()

Check the layers section of the Keras documentation.

Hope it helps.

Upvotes: 1
