anksh

Reputation: 35

TensorFlow (2.9.1): Changing the 'trainable' attribute on a layer during training

Consider the following toy model:

class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        
        self.square_layer = keras.layers.Dense(2)
        self.cube_layer = keras.layers.Dense(2)
        
        self.optimizer = tf.keras.optimizers.Adam()
    
    @tf.function
    def call(self, X):
        return tf.stack([self.square_layer(X), self.cube_layer(X)], axis=-1)
    
    @tf.function
    def train_step(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return loss

If we train using the following 'train' function and set 'self.cube_layer.trainable' to True or False, the result is as expected in both cases:

    def train(self, inputs, targets, num_epochs=5000):
        self.cube_layer.trainable = False  # True or False
        self.compile(optimizer=self.optimizer)
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
            
        print("Loss: " + str(loss))

inputs = tf.constant([[1,2]], dtype=tf.float32)
targets = tf.constant([[[3,6], [9,12]]], dtype=tf.float32)

model = MyModel()
model.train(inputs, targets)
print(model(inputs))

But if we change the 'trainable' flag during training, the result is not as expected:

    def train(self, inputs, targets, num_epochs=5000):
        self.cube_layer.trainable = False
        self.compile(optimizer=self.optimizer)
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
        
        self.cube_layer.trainable = True
        self.compile(optimizer=self.optimizer)
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
            
        print("Loss: " + str(loss))

inputs = tf.constant([[1,2]], dtype=tf.float32)
targets = tf.constant([[[3,6], [9,12]]], dtype=tf.float32)

model = MyModel()
model.train(inputs, targets)
print(model(inputs))

In the above example, if we remove the '@tf.function' decorators from 'call' and 'train_step', the result is as expected! So, I believe it has something to do with tf.function and TensorFlow graph compilation. Is there a way we can use tf.function and set the 'trainable' attribute dynamically during training? I am using TensorFlow 2.9.1.

Upvotes: 0

Views: 1266

Answers (2)

Samuel K.

Reputation: 180

I found another solution in the docstrings of base_layer.py. It's described in the setter for the trainable attribute:

@trainable.setter
def trainable(self, value):
    """Sets trainable attribute for the layer and its sublayers.

    When this value is changed during training (e.g. with a
    `tf.keras.callbacks.Callback`) you need to call the parent
    `tf.keras.Model.make_train_function` with `force=True` in order to
    recompile the training graph.

    Args:
      value: Boolean with the desired state for the layer's trainable
        attribute.
    """
    for layer in self._flatten_layers():
        layer._trainable = value

Upvotes: 2

Little Train

Reputation: 902

This is a very interesting and significant problem. Let's locate it by adding three print lines and running a short test with num_epochs=5, based on the last 'train' function in your question, i.e.:

...
    @tf.function
    def train_step(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        grads = tape.gradient(loss, self.trainable_variables)
        tf.print(len(self.trainable_variables),"in graph") # add
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss
...
    def train(self, inputs, targets, num_epochs=5):
        self.cube_layer.trainable = False
        print(len(self.trainable_variables),"before frozen") # add
        self.compile(optimizer=self.optimizer)
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
        self.cube_layer.trainable = True
        print(len(self.trainable_variables),"after frozen") # add
        self.compile(optimizer=self.optimizer)
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)

The output is:

0 before frozen
2 in graph
2 in graph
2 in graph
2 in graph
2 in graph
4 after frozen
2 in graph
2 in graph
2 in graph
2 in graph
2 in graph

Wow~, so even though you changed cube_layer's flag, and it really did change model.trainable_variables, it did not affect train_step. That is because in this code, train_step has already been converted into a graph and will not be converted again. This does not mean that once a function is converted into a computation graph it stays fixed forever; it is simply never retraced here.

😊 The deep reason is tf.function's tracing mechanism. If you repeatedly call a graphed function with the same argument types, TensorFlow skips the tracing stage and reuses the previously traced graph, since the generated graph would be identical. Here the inputs of train_step never change, so we never get a new graphed function, which makes the modification of self.cube_layer.trainable ineffective.
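This caching behaviour is easy to see outside Keras too. A minimal sketch (the `Holder` class is just for illustration): a Python-level condition is baked in at trace time, and a later change with the same input signature is ignored:

```python
import tensorflow as tf

class Holder:
    flag = True

h = Holder()

@tf.function
def f(x):
    # This Python-level branch is evaluated only while tracing,
    # not on every call of the compiled graph
    if h.flag:
        return x + 1
    return x - 1

print(f(tf.constant(1)))  # traces with flag=True -> tf.Tensor(2, ...)
h.flag = False
# Same input signature -> cached graph reused, the flag change is ignored
print(f(tf.constant(1)))  # still tf.Tensor(2, ...)
```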

So, let's fix it. In fact, this is not a bug: we should avoid mixing the high-level (compile/fit) and medium-level (tf.GradientTape) APIs. model.compile only matters for model.fit and does nothing here. So, a better way is to write it as:


class MyModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        
        self.square_layer = tf.keras.layers.Dense(2)
        self.cube_layer = tf.keras.layers.Dense(2)
        
        self.optimizer = tf.keras.optimizers.Adam()
    @tf.function
    def call(self, X):
        return tf.stack([self.square_layer(X), self.cube_layer(X)], axis=-1)
    @tf.function
    def train_step1(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss

    @tf.function
    def train_step2(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss
    def train(self, inputs, targets, num_epochs=5000):
        self.cube_layer.trainable = False
        self.train_step = self.train_step1
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
        self.cube_layer.trainable = True
        self.train_step = self.train_step2
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
        print("Loss: " + str(loss))
inputs = tf.constant([[1,2]], dtype=tf.float32)
targets = tf.constant([[[3,6], [9,12]]], dtype=tf.float32)

model = MyModel()
model.train(inputs, targets)
print(model(inputs))

And everything works as expected:

Loss: tf.Tensor(1.351493e-06, shape=(), dtype=float32)
tf.Tensor(
[[[ 3.         5.9999933]
  [ 8.999994  11.997685 ]]], shape=(1, 2, 2), dtype=float32)

Upvotes: 1
