Reputation: 36604
I want to re-use the pre-trained weights of MobileNetV2, but with images that have 12 channels. I know this requires creating more weights, but that's okay because I want to re-train anyway. I can't find a way to make it work.
import tensorflow as tf

class CNN(tf.keras.Model):
    def __init__(self):
        super(CNN, self).__init__()
        self.input_layer = tf.keras.layers.InputLayer(input_shape=(None, 224, 224, 12))
        self.base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                                      include_top=False,
                                                      weights='imagenet')
        _ = self.base._layers.pop(0)
        self.flat1 = tf.keras.layers.Flatten()
        self.dens3 = tf.keras.layers.Dense(10)

    def call(self, x, **kwargs):
        x = self.input_layer(x)
        x = self.base(x)
        x = self.flat1(x)
        x = self.dens3(x)
        return x

model = CNN()
model.build(input_shape=(None, 224, 224, 12))
ValueError: Input 0 is incompatible with layer mobilenetv2_1.00_224: expected shape=(None, 224, 224, 3), found shape=(None, 224, 224, 12)
I tried popping the first layer, as suggested in other answers, but I still get this error.
Upvotes: 1
Views: 1635
Reputation: 36604
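It's possible to load two models: one with a 12-channel input shape, and the other with the normal 3 channels. Then, copy the weights of the 3-channel model into the 12-channel model, starting from the 2nd or 3rd layer. The first convolution has to be skipped because its kernel shape depends on the number of input channels; the layers after it have the same shapes in both models.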
Here's where the weight transfer is performed:
for i in range(3, len(self.base.layers)):
    self.base.layers[i].set_weights(base_weights.layers[i].get_weights())
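The loop relies on both models having identical layer lists and on knowing how many leading layers to skip. A slightly more defensive variant copies every layer whose weight shapes match and skips the rest. The sketch below shows that logic with plain NumPy arrays standing in for what Keras' layer.get_weights() returns; transfer_matching_weights and toy_layers are hypothetical helpers, and the three-layer toy layout (input, first convolution, batch normalization) is an assumption modeled on MobileNetV2, not its real weights.

```python
import numpy as np


def transfer_matching_weights(src_layers, dst_layers):
    """Copy per-layer weights from src to dst wherever every array shape matches.

    Each argument is a list with one entry per layer; each entry is the list of
    NumPy arrays that a Keras layer.get_weights() call would return.
    Returns the updated destination list and the indices of skipped layers.
    """
    updated, skipped = [], []
    for i, (src_w, dst_w) in enumerate(zip(src_layers, dst_layers)):
        same_shapes = len(src_w) == len(dst_w) and all(
            s.shape == d.shape for s, d in zip(src_w, dst_w)
        )
        if same_shapes:
            updated.append([w.copy() for w in src_w])  # reuse pretrained values
        else:
            updated.append(dst_w)  # keep the destination's own initialization
            skipped.append(i)
    return updated, skipped


def toy_layers(in_channels, rng):
    # Hypothetical stand-in for the first three layers (assumed layout):
    # 0 = input layer (no weights), 1 = first conv kernel, 2 = batch norm
    # (gamma, beta, moving mean, moving variance).
    return [
        [],
        [rng.normal(size=(3, 3, in_channels, 32))],
        [rng.normal(size=(32,)) for _ in range(4)],
    ]


rng = np.random.default_rng(0)
pretrained = toy_layers(3, rng)  # plays the role of the 3-channel ImageNet model
target = toy_layers(12, rng)     # plays the role of the 12-channel model
new_weights, skipped = transfer_matching_weights(pretrained, target)
print(skipped)  # [1]
```

With real Keras models, the lists would come from [layer.get_weights() for layer in model.layers] and the result would be written back with layer.set_weights(...). In this toy run only the convolution is skipped, which is consistent with the answer's remark that copying can start at the 2nd layer (0-indexed) rather than the 3rd.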
Here's the whole thing:
import tensorflow as tf

h, w, c = 224, 224, 3

class CNNModel(tf.keras.Model):
    def __init__(self):
        super(CNNModel, self).__init__()
        # 12-channel MobileNetV2, randomly initialized.
        self.base = tf.keras.applications.MobileNetV2(input_shape=(h, w, 12),
                                                      include_top=False,
                                                      weights=None)
        # Regular 3-channel MobileNetV2 with ImageNet weights, used only as a weight source.
        base_weights = tf.keras.applications.MobileNetV2(input_shape=(h, w, c),
                                                         include_top=False,
                                                         weights='imagenet')
        # Both models share the same layer list, so copy weights index by index.
        # The first three layers are skipped; the first convolution among them
        # has a kernel whose shape depends on the number of input channels.
        for i in range(3, len(self.base.layers)):
            self.base.layers[i].set_weights(base_weights.layers[i].get_weights())
        del base_weights
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.drop1 = tf.keras.layers.Dropout(0.25)
        self.out = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, x, training=None, **kwargs):
        x = self.base(x, training=training)
        x = self.pool(x)
        x = self.drop1(x, training=training)
        x = self.out(x)
        return x

model = CNNModel()
model.build(input_shape=(None, h, w, 12))
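This answer leaves the first convolution of the 12-channel model randomly initialized. A different technique, not part of the answer above, is to derive that kernel from the pretrained one: average the RGB slices, tile the result across the 12 channels, and rescale by 3/12. The NumPy sketch below shows the arithmetic on a random stand-in kernel of shape (3, 3, 3, 32), which is assumed here to match MobileNetV2's first convolution at width multiplier 1.0; inflate_first_kernel is a hypothetical helper.

```python
import numpy as np


def inflate_first_kernel(kernel_rgb, new_channels):
    """Build a (kh, kw, new_channels, filters) kernel from a pretrained RGB kernel.

    The RGB slices are averaged, tiled across the new channels, and rescaled by
    c / new_channels, so an input whose channels all hold the same value gives
    the same pre-activation as the original kernel on an equal-valued RGB input.
    """
    c = kernel_rgb.shape[2]
    mean_over_rgb = kernel_rgb.mean(axis=2, keepdims=True)  # (kh, kw, 1, filters)
    tiled = np.repeat(mean_over_rgb, new_channels, axis=2)  # (kh, kw, new_channels, filters)
    return tiled * (c / new_channels)


rng = np.random.default_rng(0)
kernel_rgb = rng.normal(size=(3, 3, 3, 32))  # stand-in for the pretrained first-conv kernel
kernel_12 = inflate_first_kernel(kernel_rgb, 12)
print(kernel_12.shape)  # (3, 3, 12, 32)

# Response of every filter to a constant patch: sum of kernel weights times the value.
v = 0.5
resp_rgb = (kernel_rgb * v).sum(axis=(0, 1, 2))
resp_12 = (kernel_12 * v).sum(axis=(0, 1, 2))
print(np.allclose(resp_rgb, resp_12))  # True
```

To apply this to the model above, one could read the pretrained kernel with base_weights.get_layer('Conv1').get_weights() and write the inflated one with self.base.get_layer('Conv1').set_weights([...]). The layer name 'Conv1' and the absence of a bias term are assumptions about the installed Keras version, so check self.base.summary() first.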
Upvotes: 2
Reputation: 17219
One of the easiest ways to handle this is to pass the multi-channel input (H, W, C > 3) through a Conv2D(3, 3, padding='same') layer, which maps it to 3 channels, and then feed the result to the pretrained model.
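To see what that extra layer computes, here is a plain NumPy sketch of a stride-1, 'same'-padded 3x3 convolution from 12 channels to 3, written as a cross-correlation the way Keras' Conv2D computes it. conv2d_same is a hypothetical helper for illustration, and the random kernel stands in for weights that would be learned during training. Its parameter count, 3*3*12*3 + 3 = 327, matches the conv2d_3 row of the model summary.

```python
import numpy as np


def conv2d_same(x, kernel, bias):
    """Stride-1 2D convolution with 'same' zero padding (NHWC layout, odd kernel size).

    x: (batch, H, W, C_in); kernel: (kh, kw, C_in, C_out); bias: (C_out,).
    Like Keras' Conv2D, this is a cross-correlation (the kernel is not flipped).
    """
    kh, kw, _, c_out = kernel.shape
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((0, 0), (ph, ph), (pw, pw), (0, 0)))  # zero padding on H and W
    n, h, w, _ = x.shape
    out = np.zeros((n, h, w, c_out))
    for i in range(kh):
        for j in range(kw):
            patch = xp[:, i:i + h, j:j + w, :]  # shifted view, (n, h, w, C_in)
            out += np.einsum('nhwc,co->nhwo', patch, kernel[i, j])  # mix channels
    return out + bias


rng = np.random.default_rng(0)
x = rng.normal(size=(1, 8, 8, 12))       # a small 12-channel image
kernel = rng.normal(size=(3, 3, 12, 3))  # Conv2D(3, 3) weights for 12 input channels
bias = np.zeros(3)
y = conv2d_same(x, kernel, bias)
print(y.shape)                  # (1, 8, 8, 3)
print(kernel.size + bias.size)  # 327
```

A kernel size of 1, i.e. Conv2D(3, 1), would be a lighter per-pixel channel mix with 12*3 + 3 = 39 parameters; the 3x3 version also mixes neighboring pixels.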
import tensorflow as tf

class CNN(tf.keras.Model):
    def __init__(self):
        super(CNN, self).__init__()
        self.base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                                      include_top=False,
                                                      weights='imagenet')
        # Learnable projection from 12 input channels to the 3 channels MobileNetV2 expects.
        self.conv = tf.keras.layers.Conv2D(3, 3, padding='same')
        self.flat1 = tf.keras.layers.Flatten()
        self.dens3 = tf.keras.layers.Dense(10)

    def call(self, x, **kwargs):
        x = self.conv(x)
        x = self.base(x)
        x = self.flat1(x)
        x = self.dens3(x)
        return x

    def build_graph(self):
        # Wrap the subclassed model in a functional Model so summary() lists every layer's output shape.
        x = tf.keras.Input(shape=(224, 224, 12))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

model = CNN()
model.build(input_shape=(None, 224, 224, 12))
It simply does the job.
model(tf.ones((1, 224, 224, 12))).shape # TensorShape([1, 10])
model.build_graph().summary()
Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_13 (InputLayer)        [(None, 224, 224, 12)]    0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 224, 224, 3)       327
_________________________________________________________________
mobilenetv2_1.00_224 (Functi (None, 7, 7, 1280)        2257984
_________________________________________________________________
flatten_4 (Flatten)          (None, 62720)             0
_________________________________________________________________
dense_4 (Dense)              (None, 10)                627210
=================================================================
Total params: 2,885,521
Trainable params: 2,851,409
Non-trainable params: 34,112
_________________________________________________________________
Also, see this answer; it may help.
Upvotes: 1
Reputation: 22031
You can treat each channel as a standalone image: take each channel as an image of shape (224, 224, 1) and repeat it along the channel axis to get an image of shape (224, 224, 3). Finally, feed all 12 images to the same MobileNet.
This produces 12 outputs that you can average or combine in other ways.
This is only a possibility that I haven't tested in practice, but it can be a good starting point.
In Python, this is what I mean:
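The splitting step can be checked without TensorFlow. The NumPy sketch below, assuming a channels-last (224, 224, 12) image, builds the 12 pseudo-RGB images and averages per-image features the way an Average layer would. fake_backbone is a hypothetical stand-in for MobileNetV2 (a simple per-plane mean), used only so the example runs quickly.

```python
import numpy as np


def split_and_repeat(image):
    """Turn an (H, W, C) image into C images of shape (H, W, 3).

    Each output image is one input channel copied into all three RGB slots.
    """
    channels = np.split(image, image.shape[-1], axis=-1)  # C arrays of shape (H, W, 1)
    return [np.repeat(ch, 3, axis=-1) for ch in channels]


def fake_backbone(rgb):
    # Hypothetical stand-in for MobileNetV2: global average of each RGB plane.
    return rgb.mean(axis=(0, 1))


rng = np.random.default_rng(0)
image = rng.normal(size=(224, 224, 12))
views = split_and_repeat(image)
features = np.mean([fake_backbone(v) for v in views], axis=0)  # like Average()
print(len(views), views[0].shape, features.shape)  # 12 (224, 224, 3) (3,)
```

Because one shared model processes all 12 images, the parameter count does not grow with the number of channels, but the forward pass runs the backbone 12 times per sample.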
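import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Average, Dense, Flatten, Input, Lambda

def split_channels(image):
    # Split (batch, 224, 224, 12) into 12 tensors of shape (batch, 224, 224, 1),
    # then repeat each one 3 times along the channel axis -> (batch, 224, 224, 3).
    channels = tf.split(image, 12, axis=-1)
    return [tf.repeat(c, 3, axis=-1) for c in channels]

inp = Input((224, 224, 12))
channels = Lambda(split_channels)(inp)
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights='imagenet')
# The same base_model (shared weights) processes all 12 pseudo-RGB images.
avg = Average()([base_model(c) for c in channels])
x = Flatten()(avg)
out = Dense(10)(x)
model = Model(inp, out)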
Upvotes: 0