marlon

Reputation: 83

Taking gradients when using tf.function

I am puzzled by the behavior I observe in the following example:

import tensorflow as tf

@tf.function
def f(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c

def fplain(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c


a = tf.Variable([[0., 1.], [1., 0.]])

with tf.GradientTape() as tape:
    b, c = f(a)
    
print('tf.function gradient: ', tape.gradient([b], [c]))

# outputs: tf.function gradient:  [None]

with tf.GradientTape() as tape:
    b, c = fplain(a)
    
print('plain gradient: ', tape.gradient([b], [c]))

# outputs: plain gradient:  [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
#        [6., 2.]], dtype=float32)>]

The second (plain) output is what I would expect. How can I understand the @tf.function case?

Thank you very much in advance!

(Note that this problem is distinct from Missing gradient when using tf.function, since here all the calculations are inside the function.)

Upvotes: 7

Views: 1389

Answers (1)

user1635327

Reputation: 1641

Gradient tape does not record the individual operations inside the tf.Graph generated by @tf.function; it treats the function as a whole. Roughly speaking, f is applied to a, and the tape records the gradient of the outputs of f with respect to the input a (the only watched variable, see tape.watched_variables()). The intermediate tensor c is created inside the graph, so the tape has no record of how b depends on c, and tape.gradient([b], [c]) returns [None].
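You can check this by asking for the gradient with respect to the watched variable a instead of the intermediate c. A minimal sketch, assuming the f and a defined in the question (persistent=True is only needed because the tape is queried twice):

with tf.GradientTape(persistent=True) as tape:
    b, c = f(a)

# a is the only thing the tape is watching (it is a trainable tf.Variable)
print(tape.watched_variables())

# The gradient w.r.t. the input a is available, because "f applied to a"
# was recorded as a single unit: db/da = (2c + 2) * 2 = 8a + 4
print('gradient w.r.t. a: ', tape.gradient([b], [a]))
# expected: [[ 4., 12.], [12.,  4.]]

# The gradient w.r.t. the intermediate c is not: c was created inside the graph
print('gradient w.r.t. c: ', tape.gradient([b], [c]))
# expected: [None]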

In the second case, no graph is generated and the operations are executed eagerly, so the tape records each of them individually and everything works as expected.

A good practice is to wrap the most computationally expensive part in @tf.function (often a training loop), including the gradient computation itself. In your case it would be something like:

@tf.function
def f(a):
    with tf.GradientTape() as tape:
        c = a * 2
        b = tf.reduce_sum(c ** 2 + 2 * c)
    # The tape now lives inside the graph, so it records the individual ops
    grads = tape.gradient([b], [c])
    # Note: a Python print only runs while the function is being traced;
    # use tf.print if you want output on every call
    print('tf.function gradient: ', grads)
    return grads
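Calling the wrapped function should then reproduce the eager result from the question (a minimal usage sketch, assuming the variable a defined above):

grads = f(a)
# grads is a list with one tensor, 2c + 2 evaluated at c = 2a:
# [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
#        [6., 2.]], dtype=float32)>]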

Upvotes: 9
