safetyduck

Reputation: 6874

How to structure loops for gradient updates when data is not changing in tensorflow 2.0?

The two code snippets below use the same data to perform n gradient updates: one uses a persistent gradient tape, while the other creates a new tape on every iteration. The performance difference is roughly 2x. Is there a better way to structure this going forward? I suppose moving the data to the device would matter on a GPU?

@tf.function
def train_n_steps_same_data(tau, y, n=1):
    """
    In [218]: %timeit r = q.train_n_steps_same_data(q.tau, q.y, n=100)
    25.3 ms ± 926 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    """
    with tf.GradientTape(persistent=True) as tape:
        d = model([tau, y])
        loss = tf.reduce_mean(d['rho'])
    for i in range(n):
        gradients = tape.gradient(loss, model.trainable_variables)
        l = optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    names = [x.name for x in gradients]
    g = dict(zip(names, gradients))
    reduced = dict()
    reduced['loss'] = loss
    return reduced, d, g

@tf.function
def train_n_steps_same_data2(tau, y, n=1):
    """
    In [220]: %timeit r = q.train_n_steps_same_data2(q.tau, q.y, n=100)
    41.6 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    """
    for i in range(n):
        with tf.GradientTape() as tape:
            d = model([tau, y])
            loss = tf.reduce_mean(d['rho'])
        gradients = tape.gradient(loss, model.trainable_variables)
        l = optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    names = [x.name for x in gradients]
    g = dict(zip(names, gradients))
    reduced = dict()
    reduced['loss'] = loss
    return reduced, d, g

Upvotes: 1

Views: 629

Answers (1)

nessuno

Reputation: 27070

The first approach is definitely better: you create a single tape object and reuse it inside the loop, whereas the second function creates and destroys a tape object on every iteration.

However, you're missing a very important part in the first training loop: your tape is persistent, so after you're done using it you have to delete it manually with del tape; otherwise you cause a memory leak.
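As a minimal sketch of that pattern (using a toy single-variable "model" rather than your model and optimizer, which aren't shown in the question), the persistent tape is reused across updates and then explicitly deleted:

```python
import tensorflow as tf

# Hypothetical stand-in for the model: one trainable weight w,
# with loss = (w - 3)^2. Names here are illustrative only.
w = tf.Variable(1.0)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# Record the forward pass once with a persistent tape.
with tf.GradientTape(persistent=True) as tape:
    loss = tf.square(w - 3.0)

# Reuse the same tape for several gradient computations.
for _ in range(5):
    grads = tape.gradient(loss, [w])
    optimizer.apply_gradients(zip(grads, [w]))

# A persistent tape is not freed automatically: delete it
# once you're done to release the resources it holds.
del tape
```

Note that, as in your first function, the gradient comes from the single recorded forward pass, so it is the same on every iteration; the persistent tape only saves you from re-running the forward computation.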

Another suggestion is not to use Python's range inside a function decorated with tf.function, but to use tf.range instead (and in general, wherever possible, use the tf.* equivalent of the Python construct; see this article).
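A small sketch of the difference (with a made-up function for illustration): with tf.range, AutoGraph compiles the loop into a tf.while_loop inside the graph, instead of unrolling it at trace time as a Python range would.

```python
import tensorflow as tf

w = tf.Variable(1.0)

@tf.function
def scale_n_times(n):
    # tf.range keeps the loop inside the graph (AutoGraph converts it
    # to a tf.while_loop) rather than unrolling n iterations at trace
    # time, which Python's range would do.
    for _ in tf.range(n):
        w.assign(w * 2.0)
    return w.read_value()

# Passing n as a tensor also avoids retracing for each new value of n.
result = scale_n_times(tf.constant(3))
```

With a plain Python range(100), the traced graph would contain 100 copies of the loop body; with tf.range it contains a single loop node.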

Upvotes: 2
