Nicolas Gervais

Reputation: 36604

Replacing the input layer of a pre-trained model with different channels?

I want to reuse the pre-trained weights of MobileNetV2, but with 12-channel images. I know this requires creating additional weights, but that's fine because I plan to retrain anyway. I can't find a way to make it work.

import tensorflow as tf

class CNN(tf.keras.Model):
    def __init__(self):
        super(CNN, self).__init__()
        self.input_layer = tf.keras.layers.InputLayer(input_shape=(None, 224, 224, 12))
        self.base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                                      include_top=False,
                                                      weights='imagenet')
        _ = self.base._layers.pop(0)
        self.flat1 = tf.keras.layers.Flatten()
        self.dens3 = tf.keras.layers.Dense(10)

    def call(self, x, **kwargs):
        x = self.input_layer(x)
        x = self.base(x)
        x = self.flat1(x)
        x = self.dens3(x)
        return x

model = CNN()
model.build(input_shape=(None, 224, 224, 12))

ValueError: Input 0 is incompatible with layer mobilenetv2_1.00_224: expected shape=(None, 224, 224, 3), found shape=(None, 224, 224, 12)

I tried popping the first layer like in other answers.

Upvotes: 1

Views: 1635

Answers (3)

Nicolas Gervais

Reputation: 36604

It's possible to load two models: one whose input shape has 12 channels, and another with the usual 3 channels. Then copy the weights of the 3-channel model into the 12-channel model, starting with the 2nd or 3rd layer, because the first convolution's kernel shape depends on the number of input channels.

Here's where the weight transfer is performed:

for i in range(3, len(self.base.layers)):
    self.base.layers[i].set_weights(base_weights.layers[i].get_weights())

Here's the whole thing:

import tensorflow as tf

h, w, c = 224, 224, 3


class CNNModel(tf.keras.Model):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.base = tf.keras.applications.MobileNetV2(input_shape=(h, w, 12),
                                                      include_top=False,
                                                      weights=None)
        base_weights = tf.keras.applications.MobileNetV2(input_shape=(h, w, c),
                                                         include_top=False,
                                                         weights='imagenet')

        for i in range(3, len(self.base.layers)):
            self.base.layers[i].set_weights(base_weights.layers[i].get_weights())

        del base_weights
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.drop1 = tf.keras.layers.Dropout(0.25)
        self.out = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, x, training=None, **kwargs):
        x = self.base(x)
        x = self.pool(x)
        x = self.drop1(x)
        x = self.out(x)
        return x


model = CNNModel()

model.build(input_shape=(None, h, w, 12))
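
The fixed start index of 3 in the loop above relies on knowing which early layers depend on the channel count. A more defensive variant, sketched here under assumptions, copies a layer's weights only when every array shape matches. The helper name `copy_matching_weights` and the tiny two-layer models are hypothetical, chosen so the sketch runs without downloading ImageNet weights; the same helper could be tried on the two MobileNetV2 instances.

```python
import tensorflow as tf


def tiny_cnn(channels):
    """Hypothetical stand-in backbone; only the first conv depends on `channels`."""
    inputs = tf.keras.Input(shape=(32, 32, channels))
    x = tf.keras.layers.Conv2D(8, 3, padding="same", name="stem")(inputs)
    x = tf.keras.layers.Conv2D(16, 3, padding="same", name="body")(x)
    return tf.keras.Model(inputs, x)


def copy_matching_weights(src, dst):
    """Copy weights layer by layer wherever all array shapes agree."""
    copied, skipped = [], []
    for src_layer, dst_layer in zip(src.layers, dst.layers):
        src_w, dst_w = src_layer.get_weights(), dst_layer.get_weights()
        if not src_w and not dst_w:
            continue  # layers without weights (input, activations, ...)
        same = len(src_w) == len(dst_w) and all(
            a.shape == b.shape for a, b in zip(src_w, dst_w))
        if same:
            dst_layer.set_weights(src_w)
            copied.append(dst_layer.name)
        else:
            skipped.append(dst_layer.name)
    return copied, skipped


source = tiny_cnn(3)    # plays the role of the pretrained 3-channel model
target = tiny_cnn(12)   # plays the role of the 12-channel model
copied, skipped = copy_matching_weights(source, target)
print(copied, skipped)  # ['body'] ['stem']
```

The skipped layers keep their fresh initialization and are the ones that have to be learned from scratch during retraining.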

Upvotes: 2

Innat

Reputation: 17219

One of the easiest ways (in such a situation) is to pass the multi-channel input (H, W, C > 3) to a Conv2D(3, 3, padding='same') layer followed by the pretrained model.

import tensorflow as tf

class CNN(tf.keras.Model):
    def __init__(self):
        super(CNN, self).__init__()
        self.base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                                      include_top=False,
                                                      weights='imagenet')
        self.conv = tf.keras.layers.Conv2D(3, 3, padding='same')
        self.flat1 = tf.keras.layers.Flatten()
        self.dens3 = tf.keras.layers.Dense(10)

    def call(self, x, **kwargs):
        x = self.conv(x)
        x = self.base(x)
        x = self.flat1(x)
        x = self.dens3(x)
        return x
    
    def build_graph(self):
        x = tf.keras.Input(shape=(224, 224, 12))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

model = CNN()
model.build(input_shape=(None, 224, 224, 12))

It simply does the job.

model(tf.ones((1, 224, 224, 12))).shape # TensorShape([1, 10])
model.build_graph().summary()

Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_13 (InputLayer)        [(None, 224, 224, 12)]    0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 224, 224, 3)       327       
_________________________________________________________________
mobilenetv2_1.00_224 (Functi (None, 7, 7, 1280)        2257984   
_________________________________________________________________
flatten_4 (Flatten)          (None, 62720)             0         
_________________________________________________________________
dense_4 (Dense)              (None, 10)                627210    
=================================================================
Total params: 2,885,521
Trainable params: 2,851,409
Non-trainable params: 34,112
_____________________________________________
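
As a sanity check, the parameter counts in the summary above can be reproduced by hand. This is a small arithmetic sketch in plain Python; the helper names `conv2d_params` and `dense_params` are made up for illustration, and the backbone count is simply copied from the summary.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One kernel of size kernel_h x kernel_w x in_channels per filter, plus one bias each.
    return kernel_h * kernel_w * in_channels * filters + filters


def dense_params(in_features, units):
    # One weight per (input feature, unit) pair, plus one bias per unit.
    return in_features * units + units


conv = conv2d_params(3, 3, 12, 3)   # the extra Conv2D that maps 12 channels to 3
flat = 7 * 7 * 1280                 # size of the flattened MobileNetV2 output
dense = dense_params(flat, 10)      # the final Dense(10) layer
backbone = 2257984                  # MobileNetV2 count, copied from the summary
total = conv + backbone + dense

print(conv, flat, dense, total)     # 327 62720 627210 2885521
```

The extra convolution adds only 327 parameters; most of the parameters outside the backbone come from the Dense layer that follows Flatten.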

Also, see this answer, it may help.

Upvotes: 1

Marco Cerliani

Reputation: 22031

You can treat each channel as a standalone image: take each channel as an image of shape (224, 224, 1) and repeat it three times to get an image of shape (224, 224, 3). Finally, feed all 12 images to the same MobileNet.

This generates 12 outputs that you can average or combine in other ways.

I haven't tested this in practice, but it could be a good starting point.

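In Python, this is what I mean:

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Average, Flatten, Dense
from tensorflow.keras.models import Model

def split_channels(image):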
    channels = tf.split(image, 12, axis=-1)
    return [tf.repeat(c, [3], axis=-1) for c in channels]
    
inp = Input((224,224,12))
channels = Lambda(split_channels)(inp)
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights='imagenet')
avg = Average()([base_model(c) for c in channels])
x = Flatten()(avg)
out = Dense(10)(x)

model = Model(inp, out)
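
To see what the splitting step produces, here is a minimal check of `split_channels` on its own, run eagerly on a random batch. The 8x8 spatial size and the batch size of 2 are arbitrary assumptions to keep the check fast; they are not part of the answer.

```python
import tensorflow as tf


def split_channels(image):
    # Split the 12 channels apart, then repeat each one 3 times along the last axis.
    channels = tf.split(image, 12, axis=-1)
    return [tf.repeat(c, [3], axis=-1) for c in channels]


batch = tf.random.uniform((2, 8, 8, 12))   # small hypothetical input
parts = split_channels(batch)

print(len(parts), parts[0].shape)

# Every pseudo-RGB image holds three copies of the same source channel.
same = bool(tf.reduce_all(parts[5][..., 0] == batch[..., 5]))
print(same)
```

Each of the 12 resulting tensors has 3 identical channels, so the pretrained network sees a grayscale-like image for every original channel.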

Upvotes: 0
