Reputation: 65
I am trying to build a subclassed Model which consists of a pretrained convolutional Base and some Dense Layers on top, using Tensorflow >= 2.4. However freezing/unfreezing of the subclassed Model has no effect once it was trained before. When I do the same with the Functional API everything works as expected. I would really appreciate some Hint to what im missing here: Following Code should specify my problem further. Pardon me the amount of Code:
#Setup
import tensorflow as tf
tf.config.run_functions_eagerly(False)
import numpy as np
from tensorflow.keras.regularizers import l1
import matplotlib.pyplot as plt
@tf.function
def create_images_and_labels(img,label, height = 70, width = 70): #Image augmentation
label = tf.cast(label, 'float32')
label = tf.squeeze(label)
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.image.resize(img, (height, width))
# img = preprocess_input(img)
return img, label
cifar = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar.load_data()
num_classes = len(np.unique(y_train))
ds_train = tf.data.Dataset.from_tensor_slices((x_train, tf.one_hot(y_train, depth = len(np.unique(y_train)))))
ds_train = ds_train.map(lambda img, label: create_images_and_labels(img, label, height = 70, width = 70))
ds_train = ds_train.shuffle(50000)
ds_train = ds_train.batch(50, drop_remainder = True)
ds_val = tf.data.Dataset.from_tensor_slices((x_test, tf.one_hot(y_test, depth = len(np.unique(y_train)))))
ds_val = ds_val.map(lambda img, label: create_images_and_labels(img, label, height = 70, width = 70))
ds_val = ds_val.batch(50, drop_remainder=True)
# for i in ds_train.take(1):
# x, y = i
# for ind in range(x.shape[0]):
# plt.imshow(x[ind,:,:])
# plt.show()
# print(y[ind])
'''
Defining simple subclassed Model consisting of
VGG16
Flatten
Dense Layers
customized what happens in model.fit and model.evaluate (Actually its the standard Keras procedure with custom Metrics)
customized metrics: Loss and Accuracy for Training and Validation Step
added unfreezing Method
'set_trainable_layers'
Arguments:
num_head (How many dense Layers)
num_base (How many VGG Layers)
'''
class Test_Model(tf.keras.models.Model):
def __init__(
self,
num_unfrozen_head_layers,
num_unfrozen_base_layers,
num_classes,
conv_base = tf.keras.applications.VGG16(include_top = False, weights = 'imagenet', input_shape = (70,70,3)),
):
super(Test_Model, self).__init__(name = "Test_Model")
self.base = conv_base
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(2048, activation = 'relu')
self.dense2 = tf.keras.layers.Dense(1024, activation = 'relu')
self.dense3 = tf.keras.layers.Dense(128, activation = 'relu')
self.out = tf.keras.layers.Dense(num_classes, activation = 'softmax')
self.out._name = 'out'
self.train_loss_metric = tf.keras.metrics.Mean('Supervised Training Loss')
self.train_acc_metric = tf.keras.metrics.CategoricalAccuracy('Supervised Training Accuracy')
self.val_loss_metric = tf.keras.metrics.Mean('Supervised Validation Loss')
self.val_acc_metric = tf.keras.metrics.CategoricalAccuracy('Supervised Validation Accuracy')
self.loss_fn = tf.keras.losses.categorical_crossentropy
self.learning_rate = 1e-4
# self.build((None, 32,32,3))
self.set_trainable_layers(num_unfrozen_head_layers, num_unfrozen_base_layers)
@tf.function
def call(self, inputs, training = False):
x = self.base(inputs)
x = self.flatten(x)
x = self.dense1(x)
x = self.dense2(x)
x = self.dense3(x)
x = self.out(x)
return x
@tf.function
def train_step(self, input_data):
x_batch, y_batch = input_data
with tf.GradientTape() as tape:
tape.watch(x_batch)
y_pred = self(x_batch, training = True)
loss = self.loss_fn(y_batch, y_pred)
trainable_vars = self.trainable_weights
gradients = tape.gradient(loss, trainable_vars)
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
self.train_loss_metric.update_state(loss)
self.train_acc_metric.update_state(y_batch, y_pred)
return {"Supervised Loss": self.train_loss_metric.result(),
"Supervised Accuracy":self.train_acc_metric.result()}
@tf.function
def test_step(self, input_data):
x_batch,y_batch = input_data
y_pred = self(x_batch, training = False)
loss = self.loss_fn(y_batch, y_pred)
self.val_loss_metric.update_state(loss)
self.val_acc_metric.update_state(y_batch, y_pred)
return {"Val Supervised Loss": self.val_loss_metric.result(),
"Val Supervised Accuracy":self.val_acc_metric.result()}
@property
def metrics(self):
# We list our `Metric` objects here so that `reset_states()` can be
# called automatically at the start of each epoch
# or at the start of `evaluate()`.
# If you don't implement this property, you have to call
# `reset_states()` yourself at the time of your choosing.
return [self.train_loss_metric,
self.train_acc_metric,
self.val_loss_metric,
self.val_acc_metric]
def set_trainable_layers(self, num_head, num_base):
for layer in [lay for lay in self.layers if not isinstance(lay , tf.keras.models.Model)]:
layer.trainable = False
print(layer.name, layer.trainable)
for block in self.layers:
if isinstance(block, tf.keras.models.Model):
print('Found Submodel', block.name)
for layer in block.layers:
layer.trainable = False
print(layer.name, layer.trainable)
if num_base > 0:
for layer in block.layers[-num_base:]:
layer.trainable = True
print(layer.name, layer.trainable)
if num_head > 0:
for layer in [lay for lay in self.layers if not isinstance(lay, tf.keras.models.Model)][-num_head:]:
layer.trainable = True
print(layer.name, layer.trainable)
'''
Showcase1: First training completely frozen Model, then unfreezing:
unfreezed model doesnt learn
'''
model = Test_Model(num_unfrozen_head_layers= 0, num_unfrozen_base_layers = 0, num_classes = num_classes) # Should NOT learn -> doesnt learn
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
model.set_trainable_layers(10,20) # SHOULD LEARN -> Doesnt learn
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
#DOESNT LEARN
'''
Showcase2: when first training the Model with more trainable Layers than in the second step:
AssertionError occurs
'''
model = Test_Model(num_unfrozen_head_layers= 10, num_unfrozen_base_layers = 2, num_classes = num_classes) # SHOULD LEARN -> learns
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
model.set_trainable_layers(1,1) # SHOULD NOT LEARN -> AssertionError
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
'''
Showcase3: same Procedure as in Showcase2 but optimizer State is transferred to recompiled Model:
Cant set Weigthts because optimizer expects List of Length 0
'''
model = Test_Model(num_unfrozen_head_layers= 10, num_unfrozen_base_layers = 20, num_classes = num_classes) # SHOULD LEARN -> learns
model.build((None, 70,70,3))
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.fit(ds_train, validation_data = ds_val)
opti_state = model.optimizer.get_weights()
model.set_trainable_layers(0,0) # SHOULD NOT LEARN -> Learns
model.summary()
model.compile(optimizer = tf.keras.optimizers.Adam(1e-5))
model.optimizer.set_weights(opti_state)
model.fit(ds_train, validation_data = ds_val)
#%%%
'''
Constructing same Architecture with Functional API and running Experiments
'''
import tensorflow as tf
conv_base = tf.keras.applications.VGG16(include_top = False, weights = 'imagenet', input_shape = (70,70,3))
inputs = tf.keras.layers.Input((70,70,3))
x = conv_base(inputs)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(2048, activation = 'relu') (x)
x = tf.keras.layers.Dense(1024,activation = 'relu') (x)
x = tf.keras.layers.Dense(128,activation = 'relu') (x)
out = tf.keras.layers.Dense(num_classes,activation = 'softmax') (x)
isinstance(tf.keras.layers.Flatten(), tf.keras.models.Model)
isinstance(conv_base, tf.keras.models.Model)
def set_trainable_layers(mod, num_head, num_base):
import time
for layer in [lay for lay in mod.layers if not isinstance(lay , tf.keras.models.Model)]:
layer.trainable = False
print(layer.name, layer.trainable)
for block in mod.layers:
if isinstance(block, tf.keras.models.Model):
print('Found Submodel')
for layer in block.layers:
layer.trainable = False
print(layer.name, layer.trainable)
if num_base > 0:
for layer in block.layers[-num_base:]:
layer.trainable = True
print(layer.name, layer.trainable)
if num_head > 0:
for layer in [lay for lay in mod.layers if not isinstance(lay, tf.keras.models.Model)][-num_head:]:
layer.trainable = True
print(layer.name, layer.trainable)
'''
Showcase1: First training frozen Model, then unfreezing, recomiling and retraining:
model behaves as expected
'''
mod = tf.keras.models.Model(inputs,out, name = 'TestModel')
set_trainable_layers(mod, 0 ,0)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model should NOT learn
set_trainable_layers(mod, 10,20)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) #Model SHOULD learn
'''
Showcase2: First training unfrozen Model, then reducing number of trainable Layers:
Model behaves as Expected
'''
mod = tf.keras.models.Model(inputs,out, name = 'TestModel')
set_trainable_layers(mod, 10 ,20)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model SHOULD learn
set_trainable_layers(mod, 0,0)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) #Model should NOT learn
'''
Showcase3: First training unfrozen Model, then reducing number of trainable Layers but also trying to trasnfer Optimizer States:
Behaves as subclassed Model: New Optimizer shouldnt have Weights
'''
mod = tf.keras.models.Model(inputs,out, name = 'TestModel')
set_trainable_layers(mod, 1 ,3)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.fit(ds_train, validation_data = ds_val) # Model SHOULD learn
opti_state = mod.optimizer.get_weights()
set_trainable_layers(mod, 4,8)
mod.summary()
mod.compile(optimizer = tf.keras.optimizers.Adam(1e-5), loss = 'categorical_crossentropy', metrics = ['accuracy'])
mod.optimizer.set_weights(opti_state)
mod.fit(ds_train, validation_data = ds_val) #Model should NOT learn
Upvotes: 0
Views: 1657
Reputation: 65
After some more experiments i have found a workaround for this Problem:
While the model itself cannot be unfrozen/frozen after the first compilation and training, it is however possible to save the model weights to a temporary file model.save_weights('temp.h5')
and afterwards reconstructing the model class (Creating a new instance of model class for example) and loading the previous weights with model.load_weights('temp.h5')
.
However this can also lead to errors occuring when the previous model has both unfrozen and frozen weights. To prevent them you have to either set all layers trainable after the training and before saving weights, or copy the exact trainability structure of the model, and reconstructing the new model such that its layers have the same trainability state as the previous. this is possible with the following functions:
def get_trainability(model): # Takes Keras model and returns dictionary with layer names of Model as key, and its trainability as value/item
train_dict = {}
for layer in model.layers:
if isinstance(layer, tf.keras.models.Model):
train_dict.update(get_trainability(layer))
else:
train_dict[layer.name] = layer.trainable
return train_dict
def set_trainability(model, train_dict): # Takes keras Model and dictionary with layer names and booleans indicating the desired trainability of the layer.
# modifies model so that every Layer in the Model, whose name matches dict key will get trainable = boolean
for layer in model.layers:
if isinstance(layer, tf.keras.models.Model):
set_trainability(layer, train_dict)
else:
for name in train_dict.keys():
if name == layer.name:
layer.trainable = train_dict[name]
print(layer.name)
Hope this helps for simmilar problems in the Future
Upvotes: 1
Reputation: 56
This is happening because one of the fundamental differences between the Subclassing API and the Functional or Sequential APIs in Tensorflow2.
While the Functional or Sequential APIs build a graph of Layers (think of it as a separate data structure), the Subclassing model builds a whole object and stores it as bytecode.
This means that with Subclassing you lose access to the internal connectivity graph and the normal behaviour that allows you to freeze/unfreeze layers or reuse them in other models starts to get weird. Seeing your implementation I would say that the Subclassed model is correct and it SHOULD be working if we were dealing with a library other than Tensorflow that is.
Francois Chollet explains it better than I will ever do in one of his Tweettorials
Upvotes: 1