Dr__Soul

Reputation: 111

Pretrained TensorFlow model RGB -> RGBY channel extension

I am working on a protein analysis project. We receive images* of proteins taken with 4 filters (red, green, blue and yellow). Each of those RGBY channels contains unique data, as different cellular structures are visible under different filters.

The idea is to use a pre-trained network e.g. VGG19 and extend the number of channels from default 3 to 4. Something like this:

Picture: VGG model with RGB extended to RGBY (image: https://i.sstatic.net/TZKka.png)

The Y channel should initially be a copy of one of the existing pretrained channels, so that the pretrained weights can still be used.
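Roughly what I have in mind at the weight level (a toy NumPy sketch of my own; the shapes are just those of VGG19's first convolution):

import numpy as np

# stand-in for the pretrained first-conv kernel: 3x3 window, 3 input channels (RGB), 64 filters
rgb_kernel = np.random.rand(3, 3, 3, 64)
# append a copy of one existing channel slice as the Y channel -> 4 input channels
rgby_kernel = np.concatenate([rgb_kernel, rgb_kernel[:, :, -1:, :]], axis=2)
print(rgby_kernel.shape)  # (3, 3, 4, 64)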

Does anyone have an idea of how such an extension of a pretrained network can be achieved?

* Author of the collage - Allunia from Kaggle, "Protein Atlas - Exploration and Baseline" kernel.

Upvotes: 3

Views: 1863

Answers (2)

mjkvaak

Reputation: 549

Beyond the RGBY case, the following snippet works generally by copying or dropping dimensions of a layer's weight and/or bias arrays as needed. Refer to the NumPy documentation for what numpy.resize does: in the case of the original question it fills the extra Y-channel dimension with copies of the existing pretrained weights (and, more generally, handles any larger target shape).

import numpy as np
import tensorflow as tf
...

model = ...  # your RGBY model is here
pretrained_model = tf.keras.models.load_model(...)  # pretrained RGB model

# the following assumes that the layers match with the two models and
# only the shapes of weights and/or biases are different
for pretrained_layer, layer in zip(pretrained_model.layers, model.layers):
    pretrained = pretrained_layer.get_weights()
    target = layer.get_weights()
    if len(pretrained) == 0:  # skip input, pooling and other layers without weights
        continue
    try:  
        # set the pretrained weights as is whenever possible
        layer.set_weights(pretrained)
    except ValueError:
        # set_weights raises ValueError on a shape mismatch;
        # numpy.resize to the rescue in that case
        for idx, (l1, l2) in enumerate(zip(pretrained, target)):
            target[idx] = np.resize(l1, l2.shape)

        layer.set_weights(target)
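
For reference, here is a toy illustration of the numpy.resize behaviour the snippet relies on (the array below is made up purely for demonstration; numpy.resize fills the larger shape by repeating the flattened data):

import numpy as np

small = np.arange(6).reshape((2, 3))   # e.g. a (2, 3) weight array
larger = np.resize(small, (2, 4))      # filled by repeating the flattened data
print(larger)
# [[0 1 2 3]
#  [4 5 0 1]]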

Upvotes: 0

Dr__Soul

Reputation: 111

Use the layer.get_weights() and layer.set_weights() functions of the Keras API.

Create a template structure for the 4-channel VGG (set the input shape to (width, height, 4)). Then load the weights from the 3-channel RGB model into the 4-channel one as RGBB.

Below is the code that does the procedure. In the case of the sequential VGG, the only layer that needs to be modified is the first convolution layer; the structure of the subsequent layers is independent of the number of input channels.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from keras.applications.vgg19 import VGG19
from keras.models import Model

vgg19 = VGG19(weights='imagenet')
vgg19.summary() # To check which layers will be omitted in 'pretrained' model

# Load part of the VGG without the top layers into 'pretrained' model
pretrained = Model(inputs=vgg19.input, outputs=vgg19.get_layer('block5_pool').output)
pretrained.summary()

#%% Prepare model template with 4 input channels
config = pretrained.get_config() # run config['layers'][i] for reference
                                 # to restore the layer-by-layer structure

from keras.layers import Input, Conv2D, MaxPooling2D
from keras import optimizers

# For training from scratch, change kernel_initializer to e.g. 'VarianceScaling'
inputs = Input(shape=(224, 224, 4), name='input_17')
# block 1
x = Conv2D(64, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block1_conv1')(inputs)
x = Conv2D(64, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block1_conv2')(x)
x = MaxPooling2D(pool_size=(2, 2), name='block1_pool')(x)

# block 2
x = Conv2D(128, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block2_conv1')(x)
x = Conv2D(128, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block2_conv2')(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(2,2), name='block2_pool')(x)

# block 3
x = Conv2D(256, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block3_conv1')(x)
x = Conv2D(256, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block3_conv2')(x)
x = Conv2D(256, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block3_conv3')(x)
x = Conv2D(256, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block3_conv4')(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(2,2), name='block3_pool')(x)

# block 4
x = Conv2D(512, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block4_conv1')(x)
x = Conv2D(512, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block4_conv2')(x)
x = Conv2D(512, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block4_conv3')(x)
x = Conv2D(512, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block4_conv4')(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(2,2), name='block4_pool')(x)

# block 5
x = Conv2D(512, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block5_conv1')(x)
x = Conv2D(512, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block5_conv2')(x)
x = Conv2D(512, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block5_conv3')(x)
x = Conv2D(512, (3,3), padding='same', activation='relu', kernel_initializer='zeros', name='block5_conv4')(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(2,2), name='block5_pool')(x)

vgg_template = Model(inputs=inputs, outputs=x)

vgg_template.compile(optimizer=optimizers.RMSprop(lr=2e-4),
                     loss='categorical_crossentropy',
                     metrics=['acc'])


#%% Rewrite the weight loading/modification function
import numpy as np

layers_to_modify = ['block1_conv1'] # Turns out the only layer that changes
                                    # shape due to 4th channel is the first
                                    # convolution layer.

for layer in pretrained.layers: # pretrained Model and template have the same
                                # layers, so it doesn't matter which to 
                                # iterate over.

    if layer.get_weights() != []: # Skip input, pooling and no weights layers

        target_layer = vgg_template.get_layer(name=layer.name)

        if layer.name in layers_to_modify:

            kernels = layer.get_weights()[0]
            biases  = layer.get_weights()[1]

            kernels_extra_channel = np.concatenate((kernels,
                                                    kernels[:,:,-1:,:]),
                                                    axis=-2) # For channels_last

            target_layer.set_weights([kernels_extra_channel, biases])

        else:
            target_layer.set_weights(layer.get_weights())


#%% Save the 4-channel model populated with weights for further use

vgg_template.save('vgg19_modified_clear.hdf5')
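
To reuse the saved backbone later, one could load it back and attach a task-specific head, e.g. (a rough sketch only; load_model and the layers are standard Keras, but the Flatten/Dense head and the class count of 28 are placeholders, not part of the procedure above):

#%% Sketch: reload the saved 4-channel backbone and add a placeholder classification head
from keras.models import load_model, Model
from keras.layers import Flatten, Dense

base = load_model('vgg19_modified_clear.hdf5')
features = Flatten()(base.output)
predictions = Dense(28, activation='sigmoid')(features)  # 28 classes is a placeholder
classifier = Model(inputs=base.input, outputs=predictions)
classifier.summary()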

Upvotes: 7
