Kraken

Reputation: 1

Modifying individual layers of a TensorFlow model

I'm trying to build a model using transfer learning from Inception-ResNet-V2 with ImageNet weights. This is the part of my code that builds the model:

import tensorflow as tf

input_img_shape = (512, 512, 3)
# inception_resnet_v2 preprocessor (scales pixel values to [-1, 1])
preprocessor = tf.keras.applications.inception_resnet_v2.preprocess_input
# base inception_resnet_v2 model: no classification head, global average pooling on top
base_model = tf.keras.applications.InceptionResNetV2(weights='imagenet', include_top=False,
                                                     input_shape=input_img_shape, pooling='avg')

If I check the summary:

base_model.summary()

I get the following (I've omitted the initial layers here):

.
.
.
conv_7b (Conv2D)                (None, 14, 14, 1536) 3194880     block8_10[0][0]                  
__________________________________________________________________________________________________
conv_7b_bn (BatchNormalization) (None, 14, 14, 1536) 4608        conv_7b[0][0]                    
__________________________________________________________________________________________________
conv_7b_ac (Activation)         (None, 14, 14, 1536) 0           conv_7b_bn[0][0]                 
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1536)         0           conv_7b_ac[0][0]                 
==================================================================================================
Total params: 54,336,736
Trainable params: 54,276,192
Non-trainable params: 60,544

I want to make the BatchNormalization layers of base_model non-trainable. I use the following code:

for layer in base_model.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

Then I check the summary again:

base_model.summary()


.
.
.
.
conv_7b (Conv2D)                (None, 14, 14, 1536) 3194880     block8_10[0][0]                  
__________________________________________________________________________________________________
conv_7b_bn (BatchNormalization) (None, 14, 14, 1536) 4608        conv_7b[0][0]                    
__________________________________________________________________________________________________
conv_7b_ac (Activation)         (None, 14, 14, 1536) 0           conv_7b_bn[0][0]                 
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1536)         0           conv_7b_ac[0][0]                 
==================================================================================================
Total params: 54,336,736
Trainable params: 54,245,920
Non-trainable params: 90,816

The number of non-trainable parameters has increased by 30,272 (~30K), which is what I expected.
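The ~30K jump can be checked by counting weights directly. The conv_7b_bn row above shows 4608 = 3 × 1536 parameters, which matches a BatchNormalization layer built without a scale (gamma) term: beta is trainable, while the moving mean and variance never are. Freezing such a layer moves only beta into the non-trainable count, and summed over all BN layers that gives 90,816 − 60,544 = 30,272. A minimal sketch of that arithmetic on a single standalone layer (the helper name `count_params_split` is hypothetical, and the standalone layer is a stand-in for the real model, not code from the question):

```python
import tensorflow as tf


def count_params_split(layer):
    """Return (trainable, non_trainable) parameter counts for a layer or model."""
    trainable = sum(w.numpy().size for w in layer.trainable_weights)
    non_trainable = sum(w.numpy().size for w in layer.non_trainable_weights)
    return trainable, non_trainable


# A BN layer without gamma, sized like conv_7b_bn above (4608 = 3 * 1536).
bn = tf.keras.layers.BatchNormalization(scale=False)
bn.build((None, 14, 14, 1536))

before = count_params_split(bn)  # beta trainable; moving mean/variance not
bn.trainable = False
after = count_params_split(bn)   # all three weights now non-trainable

print(before, after)  # (1536, 3072) (0, 4608)
```

Only beta (1536 values) changes category; the total stays at 4608, just as the model's total stays at 54,336,736.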

Now I want to add another layer to this model as follows:


output_layer = tf.keras.layers.Dense(1, activation='sigmoid') 

# put them together
i = tf.keras.layers.Input([None, None, input_img_shape[2]], dtype = tf.uint8)
x = tf.cast(i, tf.float32)
x = preprocessor(x) 
x = base_model(x)
x = output_layer(x)
model = tf.keras.Model(inputs=[i], outputs=[x])

Model Summary:

model.summary()

Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
tf_op_layer_Cast (TensorFlow [(None, None, None, 3)]   0         
_________________________________________________________________
tf_op_layer_RealDiv (TensorF [(None, None, None, 3)]   0         
_________________________________________________________________
tf_op_layer_Sub (TensorFlowO [(None, 512, 512, 3)]     0         
_________________________________________________________________
inception_resnet_v2 (Functio (None, 1536)              54336736  
_________________________________________________________________
dense (Dense)                (None, 1)                 1537      
=================================================================
Total params: 54,338,273
Trainable params: 54,247,457
Non-trainable params: 90,816

This works so far. Now I need to be able to loop through the layers of this model and set `trainable` to True or False on each BatchNormalization layer (for my use case). I also need to save this model, reload it, and do the same on the reloaded model.

Modifying the base_model variable might be reflected in model, but that is not an option after saving and reloading, because then I only have model.

So I need a way to loop through all the layers of model and set `trainable` to True or False on the BatchNormalization layers only; I don't want to make every layer trainable. The loop I used for base_model doesn't work on model, because the whole Inception architecture shows up as a single layer (inception_resnet_v2) in the model summary.
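The loop finds nothing because `model.layers` is not flattened: the whole base network is one entry, and that entry is itself a `tf.keras.Model` with its own `.layers`. A minimal sketch with a small hypothetical stand-in for InceptionResNetV2 (the name "inner" and the layer sizes are assumptions; in the real model the nested entry is named inception_resnet_v2, as the summary shows):

```python
import tensorflow as tf

# Hypothetical stand-in for the pretrained base: conv -> BN -> pooling.
inner_in = tf.keras.Input(shape=(32, 32, 3))
h = tf.keras.layers.Conv2D(8, 3)(inner_in)
h = tf.keras.layers.BatchNormalization()(h)
h = tf.keras.layers.GlobalAveragePooling2D()(h)
inner = tf.keras.Model(inner_in, h, name="inner")

# Outer model wraps the base as a single layer, like `model` in the question.
outer_in = tf.keras.Input(shape=(32, 32, 3))
outer_out = tf.keras.layers.Dense(1, activation="sigmoid")(inner(outer_in))
outer = tf.keras.Model(outer_in, outer_out)

# Top level: no BatchNormalization, only the nested model as one layer.
top_level_bn = [l for l in outer.layers
                if isinstance(l, tf.keras.layers.BatchNormalization)]

# One level down, through the nested model's own .layers list.
nested = outer.get_layer("inner")
nested_bn = [l for l in nested.layers
             if isinstance(l, tf.keras.layers.BatchNormalization)]

print(len(top_level_bn), len(nested_bn))  # 0 1
```

Reaching one level down by name works here, but a base model could itself contain nested models, which is why a generic solution has to descend recursively.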

How can I loop through the layers of model and modify individual layers?

Upvotes: 0

Views: 362

Answers (1)

Lescurel

Reputation: 11631

You can use a recursive function that checks whether each layer is itself a tf.keras.Model and, if so, descends into it:

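def set_batch_norm_trainable(model, trainable=False):
    # iterate over the layers of the model passed in (not the global base_model)
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = trainable
        if isinstance(layer, tf.keras.models.Model):
            # nested model (e.g. inception_resnet_v2): recurse into its layers
            set_batch_norm_trainable(layer, trainable)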
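A self-contained sketch of this approach applied after a save/reload cycle, which is the case the question asks about. It uses a small hypothetical stand-in model (the layer name "base", the file name "model.keras" and the layer sizes are assumptions) and repeats the function so it runs on its own. Depending on the TensorFlow version, the `.keras` path may be written as a single file or as a SavedModel directory; `load_model` reads either. The Keras documentation notes that changes to `trainable` only take effect for training once the model is compiled again, hence the `compile()` call.

```python
import os
import tempfile

import tensorflow as tf


def set_batch_norm_trainable(model, trainable=False):
    """Recursively set `trainable` on every BatchNormalization layer in `model`."""
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = trainable
        elif isinstance(layer, tf.keras.Model):
            # Nested model (e.g. the pretrained base): descend into it.
            set_batch_norm_trainable(layer, trainable)


# Hypothetical small stand-in for base_model plus a Dense head.
base_in = tf.keras.Input(shape=(32, 32, 3))
h = tf.keras.layers.Conv2D(8, 3)(base_in)
h = tf.keras.layers.BatchNormalization()(h)
h = tf.keras.layers.GlobalAveragePooling2D()(h)
base = tf.keras.Model(base_in, h, name="base")

inp = tf.keras.Input(shape=(32, 32, 3))
out = tf.keras.layers.Dense(1, activation="sigmoid")(base(inp))
model = tf.keras.Model(inp, out)

# Save and reload: afterwards only the outer model is available.
path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)
reloaded = tf.keras.models.load_model(path)

set_batch_norm_trainable(reloaded, trainable=False)
# Recompile so the new trainable flags are used during training.
reloaded.compile(optimizer="adam", loss="binary_crossentropy")

bn_layers = [l for l in reloaded.get_layer("base").layers
             if isinstance(l, tf.keras.layers.BatchNormalization)]
print([l.trainable for l in bn_layers])  # [False]
```

Calling `set_batch_norm_trainable(reloaded, trainable=True)` (followed by another `compile()`) switches them back; the convolution and Dense weights are left untouched either way.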

Upvotes: 1
