Agnosie

Reputation: 65

How to freeze/unfreeze a pretrained model as part of a subclassed model in TensorFlow?

I am trying to build a subclassed model that consists of a pretrained convolutional base and some dense layers on top, using TensorFlow >= 2.4. However, freezing/unfreezing the subclassed model has no effect once it has been trained before. When I do the same with the Functional API, everything works as expected. I would really appreciate a hint as to what I'm missing here. The following code should specify my problem further (pardon the amount of code):

# Setup

import tensorflow as tf
tf.config.run_functions_eagerly(False)

import numpy as np
from tensorflow.keras.regularizers import l1
import matplotlib.pyplot as plt


@tf.function
def create_images_and_labels(img, label, height = 70, width = 70): # image preprocessing: cast and resize

    label = tf.cast(label, 'float32')
    label = tf.squeeze(label)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, (height, width))
  #  img = preprocess_input(img)

    return img, label



cifar = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test)  = cifar.load_data()
num_classes = len(np.unique(y_train))

ds_train = tf.data.Dataset.from_tensor_slices((x_train, tf.one_hot(y_train, depth = num_classes)))
ds_train = ds_train.map(lambda img, label: create_images_and_labels(img, label, height = 70, width = 70))

ds_train = ds_train.shuffle(50000)
ds_train = ds_train.batch(50, drop_remainder = True)


ds_val = tf.data.Dataset.from_tensor_slices((x_test, tf.one_hot(y_test, depth = num_classes)))
ds_val = ds_val.map(lambda img, label: create_images_and_labels(img, label, height = 70, width = 70))
ds_val = ds_val.batch(50, drop_remainder = True)



# for i in ds_train.take(1):
#     x, y = i
#     for ind in range(x.shape[0]):
#         plt.imshow(x[ind,:,:])
#         plt.show()
#         print(y[ind])


'''
Defining a simple subclassed model consisting of
VGG16
Flatten
Dense layers

customized what happens in model.fit and model.evaluate (actually it is the standard Keras procedure with custom metrics)
customized metrics: loss and accuracy for the training and validation step

added unfreezing method
'set_trainable_layers'
Arguments:
    num_head (how many dense layers)
    num_base (how many VGG layers)
'''



class Test_Model(tf.keras.models.Model):

    def __init__(
            self,
            num_unfrozen_head_layers,
            num_unfrozen_base_layers,
            num_classes,
            conv_base = tf.keras.applications.VGG16(include_top = False, weights = 'imagenet', input_shape = (70,70,3)),
            ):
                super(Test_Model, self).__init__(name = "Test_Model")

                self.base = conv_base
                self.flatten = tf.keras.layers.Flatten()
                self.dense1 = tf.keras.layers.Dense(2048, activation = 'relu')
                self.dense2 = tf.keras.layers.Dense(1024, activation = 'relu')
                self.dense3 = tf.keras.layers.Dense(128, activation = 'relu')
                self.out = tf.keras.layers.Dense(num_classes, activation = 'softmax')
                self.out._name = 'out'

                self.train_loss_metric = tf.keras.metrics.Mean('Supervised Training Loss')
                self.train_acc_metric = tf.keras.metrics.CategoricalAccuracy('Supervised Training Accuracy')
                self.val_loss_metric = tf.keras.metrics.Mean('Supervised Validation Loss')
                self.val_acc_metric = tf.keras.metrics.CategoricalAccuracy('Supervised Validation Accuracy')
                self.loss_fn = tf.keras.losses.categorical_crossentropy
                self.learning_rate = 1e-4

              #  self.build((None, 32,32,3))
                self.set_trainable_layers(num_unfrozen_head_layers, num_unfrozen_base_layers)
        
                
    @tf.function
    def call(self, inputs, training = False):
        x = self.base(inputs)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        x = self.out(x)
        return x

    @tf.function
    def train_step(self, input_data):
        x_batch, y_batch = input_data
        with tf.GradientTape() as tape:
            tape.watch(x_batch)
            y_pred = self(x_batch, training = True)
            loss = self.loss_fn(y_batch, y_pred)

        trainable_vars = self.trainable_weights
        gradients = tape.gradient(loss, trainable_vars)

        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.train_loss_metric.update_state(loss)
        self.train_acc_metric.update_state(y_batch, y_pred)

        return {"Supervised Loss": self.train_loss_metric.result(),
                "Supervised Accuracy": self.train_acc_metric.result()}
     
    @tf.function
    def test_step(self, input_data):
        x_batch,y_batch = input_data
        y_pred = self(x_batch, training = False)
        loss = self.loss_fn(y_batch, y_pred)
        
        self.val_loss_metric.update_state(loss)
        self.val_acc_metric.update_state(y_batch, y_pred)
            
        return {"Val Supervised Loss": self.val_loss_metric.result(),
                "Val Supervised Accuracy":self.val_acc_metric.result()}
    
    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [self.train_loss_metric,
                self.train_acc_metric,
                self.val_loss_metric,
                self.val_acc_metric]  
    
    def set_trainable_layers(self, num_head, num_base):
        
        for layer in [lay for lay in self.layers if not isinstance(lay , tf.keras.models.Model)]: 
            layer.trainable = False
            print(layer.name, layer.trainable)
        for block in self.layers:
            
            if isinstance(block, tf.keras.models.Model):
                print('Found Submodel', block.name)
                for layer in block.layers: 
                    layer.trainable = False
                    print(layer.name, layer.trainable)
                if num_base > 0:    
                    for layer in block.layers[-num_base:]:
                        layer.trainable = True
                        print(layer.name, layer.trainable)
        if num_head > 0: 
            for layer in [lay for lay in self.layers if not isinstance(lay, tf.keras.models.Model)][-num_head:]:
                layer.trainable = True 
                print(layer.name, layer.trainable)
        
        
    
'''
Showcase 1: First training the completely frozen model, then unfreezing:
    the unfrozen model doesn't learn
'''
    
model = Test_Model(num_unfrozen_head_layers = 0, num_unfrozen_base_layers = 0, num_classes = num_classes)    # Should NOT learn -> doesn't learn
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)


model.set_trainable_layers(10,20) # SHOULD LEARN -> doesn't learn
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
# DOESN'T LEARN
    
'''
Showcase 2: when first training the model with more trainable layers than in the second step:
    an AssertionError occurs
'''
model = Test_Model(num_unfrozen_head_layers= 10, num_unfrozen_base_layers = 2, num_classes = num_classes)    # SHOULD LEARN -> learns
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)


model.set_trainable_layers(1,1) # SHOULD NOT LEARN -> AssertionError
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)

'''
Showcase 3: same procedure as in Showcase 2, but the optimizer state is transferred to the recompiled model:
    can't set weights because the optimizer expects a list of length 0
'''

model = Test_Model(num_unfrozen_head_layers= 10, num_unfrozen_base_layers = 20, num_classes = num_classes)    # SHOULD LEARN -> learns
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)

opti_state = model.optimizer.get_weights()
model.set_trainable_layers(0,0) # SHOULD NOT LEARN -> Learns
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.optimizer.set_weights(opti_state)
model.fit(ds_train, validation_data = ds_val)
    


#%%%    
'''
Constructing the same architecture with the Functional API and running the experiments
'''

import tensorflow as tf    
conv_base = tf.keras.applications.VGG16(include_top = False, weights = 'imagenet', input_shape = (70,70,3))
    
inputs = tf.keras.layers.Input((70,70,3))
x = conv_base(inputs)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(2048, activation = 'relu') (x)
x = tf.keras.layers.Dense(1024,activation = 'relu') (x)
x = tf.keras.layers.Dense(128,activation = 'relu') (x)
out = tf.keras.layers.Dense(num_classes,activation = 'softmax') (x)
    


isinstance(tf.keras.layers.Flatten(), tf.keras.models.Model)
isinstance(conv_base, tf.keras.models.Model)


def set_trainable_layers(mod, num_head, num_base):
    for layer in [lay for lay in mod.layers if not isinstance(lay, tf.keras.models.Model)]:
        layer.trainable = False
        print(layer.name, layer.trainable)
    for block in mod.layers:
        
        if isinstance(block, tf.keras.models.Model):
            print('Found Submodel')
            for layer in block.layers: 
                layer.trainable = False
                print(layer.name, layer.trainable)
            if num_base > 0:    
                for layer in block.layers[-num_base:]:
                    layer.trainable = True
                    print(layer.name, layer.trainable)
    if num_head > 0: 
        for layer in [lay for lay in mod.layers if not isinstance(lay, tf.keras.models.Model)][-num_head:]:
            layer.trainable = True 
            print(layer.name, layer.trainable)
       
    
    
'''
Showcase 1: First training the frozen model, then unfreezing, recompiling and retraining:
    the model behaves as expected
'''
mod = tf.keras.models.Model(inputs,out, name = 'TestModel')    
set_trainable_layers(mod, 0 ,0)    
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model should NOT learn


set_trainable_layers(mod, 10,20)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) #Model SHOULD learn


'''
Showcase 2: First training the unfrozen model, then reducing the number of trainable layers:
    the model behaves as expected
'''


mod = tf.keras.models.Model(inputs,out, name = 'TestModel')    
set_trainable_layers(mod, 10 ,20)    
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model SHOULD learn


set_trainable_layers(mod, 0,0)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) #Model should NOT learn




'''
Showcase 3: First training the unfrozen model, then changing the number of trainable layers, but also trying to transfer the optimizer state:
    behaves like the subclassed model: the new optimizer shouldn't have weights yet
'''


mod = tf.keras.models.Model(inputs,out, name = 'TestModel')    
set_trainable_layers(mod, 1 ,3)    
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model SHOULD learn

opti_state = mod.optimizer.get_weights()
set_trainable_layers(mod, 4,8)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.optimizer.set_weights(opti_state)
mod.fit(ds_train, validation_data = ds_val) #Model should NOT learn

Upvotes: 0

Views: 1657

Answers (2)

Agnosie

Reputation: 65

After some more experiments I have found a workaround for this problem:

While the model itself cannot be frozen/unfrozen after the first compilation and training, it is possible to save the model weights to a temporary file with model.save_weights('temp.h5'), then reconstruct the model (e.g. create a new instance of the model class) and load the previous weights with model.load_weights('temp.h5').
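
For example, the basic round trip could look like this (a sketch reusing the Test_Model class and datasets from the question; the freeze configuration here is just an example):

# Save the trained weights, rebuild the model with the desired freeze
# configuration, then restore the weights into the fresh instance.
model.save_weights('temp.h5')

new_model = Test_Model(num_unfrozen_head_layers = 10,
                       num_unfrozen_base_layers = 20,
                       num_classes = num_classes)
new_model.build((None, 70, 70, 3))   # build first so the variables exist
new_model.load_weights('temp.h5')

new_model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
new_model.fit(ds_train, validation_data = ds_val)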

However, this can lead to errors when the previous model has both frozen and unfrozen weights. To prevent them, you either have to set all layers trainable after training and before saving the weights, or copy the exact trainability structure of the model and reconstruct the new model so that its layers have the same trainability state as before. This is possible with the following functions:

def get_trainability(model): # takes a Keras model and returns a dictionary mapping layer names to their trainability
    train_dict = {}
    for layer in model.layers:
        if isinstance(layer, tf.keras.models.Model): # recurse into nested models
            train_dict.update(get_trainability(layer))
        else:
            train_dict[layer.name] = layer.trainable
    return train_dict


def set_trainability(model, train_dict): # takes a Keras model and a dictionary mapping layer names to the desired trainability;
                                         # every layer whose name matches a key gets trainable = train_dict[name]
    for layer in model.layers:
        if isinstance(layer, tf.keras.models.Model): # recurse into nested models
            set_trainability(layer, train_dict)
        elif layer.name in train_dict:
            layer.trainable = train_dict[layer.name]
            print(layer.name)
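
Put together, a possible end-to-end use of these helpers looks like this (assuming the layer names of the new instance match those of the old one, otherwise the dictionary keys will not line up):

# Record the trainability structure and weights of the trained model,
# rebuild it, then restore the trainability first and the weights after.
train_dict = get_trainability(model)
model.save_weights('temp.h5')

new_model = Test_Model(num_unfrozen_head_layers = 0,
                       num_unfrozen_base_layers = 0,
                       num_classes = num_classes)
new_model.build((None, 70, 70, 3))
set_trainability(new_model, train_dict)   # reproduce the old freeze pattern
new_model.load_weights('temp.h5')
new_model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))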

Hope this helps with similar problems in the future.

Upvotes: 1

IpastorSan

Reputation: 56

This is happening because of one of the fundamental differences between the Subclassing API and the Functional or Sequential APIs in TensorFlow 2.

While the Functional or Sequential APIs build a graph of layers (think of it as a separate data structure), a subclassed model builds a whole Python object and stores its logic as bytecode.

This means that with subclassing you lose access to the internal connectivity graph, and the normal behaviour that lets you freeze/unfreeze layers or reuse them in other models starts to get weird. Looking at your implementation, I would say the subclassed model is correct and it SHOULD be working, if we were dealing with a library other than TensorFlow, that is.
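
To illustrate the difference with a quick sketch (a toy example, not specific to your model):

import tensorflow as tf

# Functional model: the connectivity graph is a data structure Keras records.
inputs = tf.keras.layers.Input((4,))
x = tf.keras.layers.Dense(8, name = 'hidden')(inputs)
out = tf.keras.layers.Dense(1, name = 'head')(x)
func = tf.keras.models.Model(inputs, out)

# Slicing a sub-model out of an intermediate tensor only works because
# Keras knows how the layers are wired together.
feature_extractor = tf.keras.models.Model(inputs, func.get_layer('hidden').output)

# Subclassed model: the forward pass is arbitrary Python in call(), so
# Keras cannot recover an equivalent graph from the object.
class Sub(tf.keras.models.Model):
    def __init__(self):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8)
        self.head = tf.keras.layers.Dense(1)

    def call(self, x):
        return self.head(self.hidden(x))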

Francois Chollet explains it better than I ever will in one of his Tweettorials.

Upvotes: 1
