Reputation: 101
I am trying to write my own training loop for TF2/Keras, following the official Keras walkthrough. The vanilla version works like a charm, but when I add the @tf.function decorator to my training step, a memory leak consumes all my memory and I lose control of my machine. Does anyone know what is going on?
The important parts of the code look like this:
    @tf.function
    def train_step(x, y):
        with tf.GradientTape() as tape:
            logits = siamese_network(x, training=True)
            loss_value = loss_fn(y, logits)
        grads = tape.gradient(loss_value, siamese_network.trainable_weights)
        optimizer.apply_gradients(zip(grads, siamese_network.trainable_weights))
        train_acc_metric.update_state(y, logits)
        return loss_value

    @tf.function
    def test_step(x, y):
        val_logits = siamese_network(x, training=False)
        val_acc_metric.update_state(y, val_logits)
        val_prec_metric.update_state(y_batch_val, val_logits)
        val_rec_metric.update_state(y_batch_val, val_logits)

    for epoch in range(epochs):
        step_time = 0
        epoch_time = time.time()
        print("Start of {} epoch".format(epoch))
        for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
            if step > steps_epoch:
                break
            loss_value = train_step(x_batch_train, y_batch_train)
        train_acc = train_acc_metric.result()
        train_acc_metric.reset_states()
        for val_step, (x_batch_val, y_batch_val) in enumerate(test_ds):
            if val_step > validation_steps:
                break
            test_step(x_batch_val, y_batch_val)
        val_acc = val_acc_metric.result()
        val_prec = val_prec_metric.result()
        val_rec = val_rec_metric.result()
        val_acc_metric.reset_states()
        val_prec_metric.reset_states()
        val_rec_metric.reset_states()
If I comment out the @tf.function lines, the memory leak doesn't occur, but each step is about 3 times slower. My guess is that the graph is somehow being created again within each epoch or something like that, but I have no idea how to solve it.
This is the tutorial I am following: https://keras.io/guides/writing_a_training_loop_from_scratch/
Upvotes: 10
Views: 3012
Reputation: 9886
TensorFlow may be generating a new graph for each unique set of argument values passed into the decorated functions. Make sure you are passing consistently shaped Tensor objects to test_step and train_step instead of Python objects.
This is a stab in the dark. While I've never tried @tf.function, I did find the following warnings in the documentation:
tf.function also treats any pure Python value as opaque objects, and builds a separate graph for each set of Python arguments that it encounters.
and
Caution: Passing python scalars or lists as arguments to tf.function will always build a new graph. To avoid this, pass numeric arguments as Tensors whenever possible
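A rough way to see this caution in action is to count how often the Python body of a decorated function runs, since that body only executes while TensorFlow is tracing. The sketch below is not from the question: the function `scaled` and the list `trace_log` are hypothetical names made up for illustration.

```python
import tensorflow as tf

trace_log = []  # grows only while TensorFlow traces the Python body


@tf.function
def scaled(x, factor):
    # Python side effect: executed at trace time, not on every call.
    trace_log.append("traced")
    return x * factor


x = tf.constant([1.0, 2.0])

# Python floats: each distinct value becomes part of the cache key,
# so every new value forces a new trace.
for step in range(4):
    scaled(x, float(step))
python_scalar_traces = len(trace_log)

trace_log.clear()

# Scalar tensors: keyed by shape and dtype only, so one trace is reused.
for step in range(4):
    scaled(x, tf.constant(float(step)))
tensor_traces = len(trace_log)

print("traces with Python floats:", python_scalar_traces)
print("traces with scalar tensors:", tensor_traces)
```

If the trace count keeps growing with the number of calls, each of those traces is a separate graph held in memory, which would be consistent with the kind of memory growth described in the question.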
Finally:
A Function determines whether to reuse a traced ConcreteFunction by computing a cache key from an input's args and kwargs. A cache key is a key that identifies a ConcreteFunction based on the input args and kwargs of the Function call, according to the following rules (which may change):
- The key generated for a tf.Tensor is its shape and dtype.
- The key generated for a tf.Variable is a unique variable id.
- The key generated for a Python primitive (like int, float, str) is its value.
- The key generated for nested dicts, lists, tuples, namedtuples, and attrs is the flattened tuple of leaf-keys (see nest.flatten). (As a result of this flattening, calling a concrete function with a different nesting structure than the one used during tracing will result in a TypeError).
- For all other Python types the key is unique to the object. This way a function or method is traced independently for each instance it is called with.
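The first rule in that list (tensors are keyed by shape and dtype) suggests that batches of varying shape could also cause repeated tracing. Here is a minimal sketch of that effect, and of pinning the signature with `input_signature` so that one graph covers every batch size. The names `free_step`, `pinned_step`, and the feature width of 3 are assumptions for illustration, not taken from the question.

```python
import tensorflow as tf

free_traces = []    # one entry per trace of free_step
pinned_traces = []  # one entry per trace of pinned_step


@tf.function
def free_step(batch):
    free_traces.append("traced")  # runs only at trace time
    return tf.reduce_sum(batch)


# A None batch dimension tells TensorFlow to build one graph that
# accepts any batch size with 3 float32 features per row.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def pinned_step(batch):
    pinned_traces.append("traced")
    return tf.reduce_sum(batch)


for batch_size in (2, 4, 8):
    batch = tf.ones([batch_size, 3])
    free_step(batch)    # new shape each time, so a new trace each time
    pinned_step(batch)  # matches the fixed signature, so the trace is reused

print("free_step traces:", len(free_traces))
print("pinned_step traces:", len(pinned_traces))
```

Whether this applies to the question depends on how the inputs to train_step vary between calls, which the posted code does not show.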
What I get from all this is that if you don't pass a consistently sized Tensor object to your @tf.function-decorated function (perhaps you use Python collections or primitives instead), you are likely creating a new graph version of your function for every distinct argument value you pass in. I'm guessing this could cause the memory explosion you're seeing. I can't tell how your test_ds and train_ds objects are being created, but you might want to make sure that they are created such that enumerate(blah_ds) returns tensors like in the tutorial, or at least convert the values to tensors before passing them to your test_step and train_step functions.
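As a sketch of what "returns tensors" could look like: building the dataset with `tf.data` and batching with `drop_remainder=True` yields tensors of one fixed shape, so the step function should only be traced once. The data, shapes, and the stand-in body of `train_step` below are all assumptions, since the question does not show how train_ds is built.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data; replace with the real features and labels.
features = np.random.rand(10, 4).astype("float32")
labels = np.random.randint(0, 2, size=(10,)).astype("float32")

# drop_remainder=True discards the short final batch, so every batch
# has the same shape and cannot trigger an extra trace.
train_ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(
    4, drop_remainder=True
)

trace_log = []  # grows only while TensorFlow traces train_step


@tf.function
def train_step(x, y):
    trace_log.append("traced")  # Python side effect, trace time only
    # Stand-in computation; the real step would run the model and optimizer.
    return tf.reduce_mean(x) + tf.reduce_mean(y)


for epoch in range(3):
    for x_batch, y_batch in train_ds:
        train_step(x_batch, y_batch)

print("train_step traces:", len(trace_log))
```

If the data comes from a Python generator or NumPy arrays instead, wrapping each batch in `tf.convert_to_tensor` before calling the step is the other option mentioned above.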
Upvotes: 3