Reputation: 35
Consider the following toy model:
class MyModel(keras.Model):

    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.square_layer = keras.layers.Dense(2)
        self.cube_layer = keras.layers.Dense(2)
        self.optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def call(self, X):
        return tf.stack([self.square_layer(X), self.cube_layer(X)], axis=-1)

    @tf.function
    def train_step(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return loss
If we train using the following 'train' function and set 'self.cube_layer.trainable' to True or False, the result is as expected in both cases:
def train(self, inputs, targets, num_epochs=5000):
    self.cube_layer.trainable = False  # True or False
    self.compile(optimizer=self.optimizer)
    for epoch in range(num_epochs):
        loss = self.train_step(inputs, targets)
    print("Loss: " + str(loss))

inputs = tf.constant([[1, 2]], dtype=tf.float32)
targets = tf.constant([[[3, 6], [9, 12]]], dtype=tf.float32)

model = MyModel()
model.train(inputs, targets)
print(model(inputs))
But, if we change the 'trainable' flag during training, the result is not as expected:
def train(self, inputs, targets, num_epochs=5000):
    self.cube_layer.trainable = False
    self.compile(optimizer=self.optimizer)
    for epoch in range(num_epochs):
        loss = self.train_step(inputs, targets)

    self.cube_layer.trainable = True
    self.compile(optimizer=self.optimizer)
    for epoch in range(num_epochs):
        loss = self.train_step(inputs, targets)
    print("Loss: " + str(loss))

inputs = tf.constant([[1, 2]], dtype=tf.float32)
targets = tf.constant([[[3, 6], [9, 12]]], dtype=tf.float32)

model = MyModel()
model.train(inputs, targets)
print(model(inputs))
In the above example, if we remove the '@tf.function' decorators from 'call' and 'train_step', the result is as expected! So I believe it has something to do with tf.function and TensorFlow graph compilation. Is there a way we can use tf.function and still set the 'trainable' attribute dynamically during training? I am using TensorFlow 2.9.1.
Upvotes: 0
Views: 1266
Reputation: 180
I found another solution in the docs of base_layer.py. It is described in the setter for the trainable attribute:
@trainable.setter
def trainable(self, value):
    """Sets trainable attribute for the layer and its sublayers.

    When this value is changed during training (e.g. with a
    `tf.keras.callbacks.Callback`) you need to call the parent
    `tf.keras.Model.make_train_function` with `force=True` in order to
    recompile the training graph.

    Args:
      value: Boolean with the desired state for the layer's trainable
        attribute.
    """
    for layer in self._flatten_layers():
        layer._trainable = value
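For what it's worth, here is a minimal sketch of how that advice could be applied with the high-level fit() API: a small standalone model (not the MyModel from the question) whose second layer is unfrozen from a Callback, after which make_train_function(force=True) forces the cached training graph to be retraced. The layer name, the epoch threshold and the random data are illustrative assumptions.

import numpy as np
import tensorflow as tf

# Hypothetical standalone model; `frozen_layer` plays the role of cube_layer.
frozen_layer = tf.keras.layers.Dense(2, name="frozen_layer")
model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(2,)),
    frozen_layer,
])

class UnfreezeCallback(tf.keras.callbacks.Callback):
    """Unfreezes `frozen_layer` at an (illustrative) epoch and forces Keras
    to rebuild its cached train_function, as the docstring advises."""
    def on_epoch_begin(self, epoch, logs=None):
        if epoch == 5:
            frozen_layer.trainable = True
            self.model.make_train_function(force=True)

frozen_layer.trainable = False
model.compile(optimizer="adam", loss="mse")
x = np.random.rand(8, 2).astype("float32")
y = np.random.rand(8, 2).astype("float32")
model.fit(x, y, epochs=10, callbacks=[UnfreezeCallback()], verbose=0)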
Upvotes: 2
Reputation: 902
This is a very interesting and significant problem. Let's locate the problem by adding three print lines and doing a little 5-epoch test, based on the last train function in your question, i.e.:
...
@tf.function
def train_step(self, inputs, targets):
    with tf.GradientTape() as tape:
        predictions = self(inputs)
        loss = tf.reduce_mean(tf.square(predictions - targets))
    grads = tape.gradient(loss, self.trainable_variables)
    tf.print(len(self.trainable_variables), "in graph")  # add
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    return loss
...
def train(self, inputs, targets, num_epochs=5):
    self.cube_layer.trainable = False
    print(len(self.trainable_variables), "before frozen")  # add
    self.compile(optimizer=self.optimizer)
    for epoch in range(num_epochs):
        loss = self.train_step(inputs, targets)

    self.cube_layer.trainable = True
    print(len(self.trainable_variables), "after frozen")  # add
    self.compile(optimizer=self.optimizer)
    for epoch in range(num_epochs):
        loss = self.train_step(inputs, targets)
The output is:
0 before frozen
2 in graph
2 in graph
2 in graph
2 in graph
2 in graph
4 after frozen
2 in graph
2 in graph
2 in graph
2 in graph
2 in graph
Wow~, even though you changed cube_layer's flag, and that did change model.trainable_variables, it did not affect train_step.

That is because in this code train_step has already been converted into a graph and is not converted again. This does not mean that once a function is converted into a computation graph it will always remain unchanged. 😊 The deeper reason is tf.function's tracing mechanism: if you repeatedly call a graphed function with the same argument types, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical. Obviously, here the input of train_step did not change, so we never get a new graphed function, which makes the modification of self.cube_layer.trainable ineffective.
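To see this tracing behaviour in isolation, here is a minimal, self-contained sketch (using a hypothetical standalone Dense layer, not the model from the question): whatever len(...) returns at trace time is baked into the graph and reused on later calls with the same input signature.

import tensorflow as tf

layer = tf.keras.layers.Dense(2)
layer.build((None, 2))  # create kernel and bias so trainable_variables is non-empty

@tf.function
def count_trainable(x):
    # len() runs in Python at trace time, so its result is frozen into the graph.
    tf.print("trainable vars seen by the graph:", len(layer.trainable_variables))
    return layer(x)

x = tf.constant([[1.0, 2.0]])
count_trainable(x)       # first call traces the graph: prints 2
layer.trainable = False
count_trainable(x)       # same input signature -> old graph reused: still prints 2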
So, let's fix it. In fact, it is not a bug; we just should not mix the high-level (compile, fit) and medium-level (tf.GradientTape) APIs. model.compile only takes effect for model.fit and does nothing here.

So a better way here can be written as:
class MyModel(tf.keras.Model):

    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.square_layer = tf.keras.layers.Dense(2)
        self.cube_layer = tf.keras.layers.Dense(2)
        self.optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def call(self, X):
        return tf.stack([self.square_layer(X), self.cube_layer(X)], axis=-1)

    @tf.function
    def train_step1(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss

    @tf.function
    def train_step2(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss

    def train(self, inputs, targets, num_epochs=5000):
        self.cube_layer.trainable = False
        self.train_step = self.train_step1
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)

        self.cube_layer.trainable = True
        self.train_step = self.train_step2
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
        print("Loss: " + str(loss))

inputs = tf.constant([[1, 2]], dtype=tf.float32)
targets = tf.constant([[[3, 6], [9, 12]]], dtype=tf.float32)

model = MyModel()
model.train(inputs, targets)
print(model(inputs))
And everything is OK:
Loss: tf.Tensor(1.351493e-06, shape=(), dtype=float32)
tf.Tensor(
[[[ 3. 5.9999933]
[ 8.999994 11.997685 ]]], shape=(1, 2, 2), dtype=float32)
Upvotes: 1