Reputation: 6874
The two code snippets below use the same data to do n updates; one uses a persistent gradient tape, while the other re-creates the tape on every iteration. The performance difference seems to be about 2x. Is there a better way to structure this going forward? I suppose moving the data to the device would matter on a GPU?
@tf.function
def train_n_steps_same_data(tau, y, n=1):
    """
    In [218]: %timeit r = q.train_n_steps_same_data(q.tau, q.y, n=100)
    25.3 ms ± 926 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    """
    with tf.GradientTape(persistent=True) as tape:
        d = model([tau, y])
        loss = tf.reduce_mean(d['rho'])
    for i in range(n):
        gradients = tape.gradient(loss, model.trainable_variables)
        l = optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    names = [x.name for x in gradients]
    g = dict(zip(names, gradients))
    reduced = dict()
    reduced['loss'] = loss
    return reduced, d, g
@tf.function
def train_n_steps_same_data2(tau, y, n=1):
    """
    In [220]: %timeit r = q.train_n_steps_same_data2(q.tau, q.y, n=100)
    41.6 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    """
    for i in range(n):
        with tf.GradientTape() as tape:
            d = model([tau, y])
            loss = tf.reduce_mean(d['rho'])
        gradients = tape.gradient(loss, model.trainable_variables)
        l = optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    names = [x.name for x in gradients]
    g = dict(zip(names, gradients))
    reduced = dict()
    reduced['loss'] = loss
    return reduced, d, g
Upvotes: 1
Views: 629
Reputation: 27070
The first approach is definitely better: you create a single tape object and reuse it inside the loop, whereas the second function creates and destroys a tape object on every iteration.
However, you're missing a very important part in the first training loop: your tape is persistent, so after you've used it you have to delete it manually with del tape; otherwise you're causing a memory leak.
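As a minimal sketch, the first function could release the tape like this (assuming model and optimizer are defined as in the question; the extra returned dictionaries are omitted here for brevity):

import tensorflow as tf

@tf.function
def train_n_steps_same_data(tau, y, n=1):
    with tf.GradientTape(persistent=True) as tape:
        d = model([tau, y])
        loss = tf.reduce_mean(d['rho'])
    for i in range(n):
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # drop the reference to the persistent tape so its resources can be freed
    del tape
    return loss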
Another suggestion is not to use range when you decorate a function with tf.function, but to use tf.range instead (and in general, wherever possible, use the tf.* equivalent of the Python construct; see this article).
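To illustrate the range vs. tf.range point with a toy example separate from the asker's model: with a Python range the loop body is unrolled into the graph once per iteration at trace time, while with tf.range AutoGraph converts the loop into a single tf.while_loop. Whether tape.gradient can be called inside the resulting while loop in the asker's exact setup is worth verifying separately.

import tensorflow as tf

@tf.function
def unrolled_sum(n):
    total = tf.constant(0)
    for i in range(n):      # n must be a Python int; the body is traced n times
        total += i
    return total

@tf.function
def while_loop_sum(n):
    total = tf.constant(0)
    for i in tf.range(n):   # n can be a tensor; AutoGraph emits one tf.while_loop
        total += i
    return total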
Upvotes: 2