Reputation: 1
I'm trying to build a model using transfer learning from InceptionResNetV2 with ImageNet weights. This is the part of my code that builds the model:
import tensorflow as tf

input_img_shape = (512, 512, 3)
# inception_resnet_v2 preprocessor:
preprocessor = tf.keras.applications.inception_resnet_v2.preprocess_input
# base inception_resnet_v2 model
base_model = tf.keras.applications.InceptionResNetV2(weights='imagenet', include_top=False,
                                                     input_shape=input_img_shape, pooling='avg')
If I check the summary
base_model.summary()
I get this as my model parameters (I've omitted the initial layers here):
.
.
.
conv_7b (Conv2D) (None, 14, 14, 1536) 3194880 block8_10[0][0]
__________________________________________________________________________________________________
conv_7b_bn (BatchNormalization) (None, 14, 14, 1536) 4608 conv_7b[0][0]
__________________________________________________________________________________________________
conv_7b_ac (Activation) (None, 14, 14, 1536) 0 conv_7b_bn[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1536) 0 conv_7b_ac[0][0]
==================================================================================================
Total params: 54,336,736
Trainable params: 54,276,192
Non-trainable params: 60,544
I want to make the BatchNormalization layers of base_model non-trainable, so I use the following code:
for layer in base_model.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False
I get the following as my model
base_model.summary()
.
.
.
.
conv_7b (Conv2D) (None, 14, 14, 1536) 3194880 block8_10[0][0]
__________________________________________________________________________________________________
conv_7b_bn (BatchNormalization) (None, 14, 14, 1536) 4608 conv_7b[0][0]
__________________________________________________________________________________________________
conv_7b_ac (Activation) (None, 14, 14, 1536) 0 conv_7b_bn[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1536) 0 conv_7b_ac[0][0]
==================================================================================================
Total params: 54,336,736
Trainable params: 54,245,920
Non-trainable params: 90,816
The non-trainable parameter count has increased by 30,272 (from 60,544 to 90,816), which is what I expected.
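The size of that jump can be checked from the numbers in the two summaries. This is only a rough sketch, and it assumes the BatchNormalization layers here carry three parameters per channel (beta, moving mean and moving variance, with no gamma), which is what 4608 / 1536 = 3 for conv_7b_bn suggests:

```python
# Numbers copied from the two base_model.summary() outputs above.
non_trainable_before = 60_544   # before freezing BatchNormalization
non_trainable_after = 90_816    # after freezing BatchNormalization
trainable_before = 54_276_192
trainable_after = 54_245_920

# conv_7b_bn has 4608 parameters over 1536 channels.
# Assumption: the 3 per channel are beta, moving mean and moving variance.
params_per_channel = 4608 // 1536

# Moving mean and variance are never updated by gradients, so before
# freezing, 2 of the 3 parameters per channel are already non-trainable.
total_bn_channels = non_trainable_before // 2

# Freezing moves the one remaining parameter per channel (beta) over.
moved = non_trainable_after - non_trainable_before

print(params_per_channel, total_bn_channels, moved)
```

Under that assumption the counts line up: the number of parameters that moved equals the total number of BatchNormalization channels, and the trainable count drops by the same amount.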
Now I want to add another layer to this model as follows:
output_layer = tf.keras.layers.Dense(1, activation='sigmoid')
# put them together
i = tf.keras.layers.Input([None, None, input_img_shape[2]], dtype = tf.uint8)
x = tf.cast(i, tf.float32)
x = preprocessor(x)
x = base_model(x)
x = output_layer(x)
model = tf.keras.Model(inputs=[i], outputs=[x])
Model Summary:
model.summary()
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, None, None, 3)] 0
_________________________________________________________________
tf_op_layer_Cast (TensorFlow [(None, None, None, 3)] 0
_________________________________________________________________
tf_op_layer_RealDiv (TensorF [(None, None, None, 3)] 0
_________________________________________________________________
tf_op_layer_Sub (TensorFlowO [(None, 512, 512, 3)] 0
_________________________________________________________________
inception_resnet_v2 (Functio (None, 1536) 54336736
_________________________________________________________________
dense (Dense) (None, 1) 1537
=================================================================
Total params: 54,338,273
Trainable params: 54,247,457
Non-trainable params: 90,816
This works fine so far. Now I need to be able to loop through the layers of this model and set trainable = True or False on every tf.keras.layers.BatchNormalization layer (for my use case). I also need to save this model, reload it, and do the same on the reloaded model.
Modifying the base_model variable might propagate the change to model, but I can't rely on that after saving and reloading the model.
So I need a way to loop through all the layers of model and set trainable = True or False on the BatchNormalization layers only; I don't want to make every layer trainable. The code I used for base_model doesn't work on model, because the whole Inception architecture shows up as a single layer (inception_resnet_v2) in the model summary.
How can I loop through the layers of model and modify individual layers?
Upvotes: 0
Views: 362
Reputation: 11631
You can use a recursive function that checks whether a layer is itself a tf.keras.Model and, if so, descends into its layers:
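def set_batch_norm_trainable(model, trainable=False):
    # Walk the layers of the model that was passed in (not the global base_model)
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = trainable
        # A nested model (e.g. inception_resnet_v2) shows up as a single layer; recurse into it
        if isinstance(layer, tf.keras.models.Model):
            set_batch_norm_trainable(layer, trainable)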
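To see the recursion work without downloading the ImageNet weights, here is a small sketch of the same idea. It is a variation, not the function above: it takes the layer type as a parameter, finds nested models by checking for a .layers attribute, and returns how many layers it changed. FakeLayer, FakeBatchNorm and FakeModel are hypothetical stand-ins for the Keras classes, used only so the example runs without TensorFlow installed.

```python
def set_trainable_by_type(model, layer_type, trainable):
    """Recursively set `trainable` on every layer of `layer_type`.

    Returns the number of layers changed. Nested models are detected by
    duck typing: anything with a `.layers` attribute is descended into.
    """
    changed = 0
    for layer in model.layers:
        if isinstance(layer, layer_type):
            layer.trainable = trainable
            changed += 1
        elif hasattr(layer, "layers"):
            changed += set_trainable_by_type(layer, layer_type, trainable)
    return changed


# Hypothetical stand-ins for Keras classes (not part of any library).
class FakeLayer:
    def __init__(self, name):
        self.name = name
        self.trainable = True


class FakeBatchNorm(FakeLayer):
    pass


class FakeModel(FakeLayer):
    def __init__(self, name, layers):
        super().__init__(name)
        self.layers = layers


# Same shape as in the question: a backbone nested inside an outer model.
backbone = FakeModel("inception_resnet_v2", [
    FakeLayer("conv_7b"),
    FakeBatchNorm("conv_7b_bn"),
    FakeLayer("conv_7b_ac"),
])
outer = FakeModel("model", [FakeLayer("input_2"), backbone, FakeLayer("dense")])

n_frozen = set_trainable_by_type(outer, FakeBatchNorm, trainable=False)
print(n_frozen)
print([(layer.name, layer.trainable) for layer in backbone.layers])
```

With the real model the call would be set_trainable_by_type(model, tf.keras.layers.BatchNormalization, False). Since it only relies on .layers, the same call should also work on the object returned by tf.keras.models.load_model. Keep in mind that Keras expects compile() to be called again after trainable is changed, otherwise the change does not affect training.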
Upvotes: 1