Reputation: 83
I am puzzled by the behavior I observe in the following example:
import tensorflow as tf
@tf.function
def f(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c
def fplain(a):
    c = a * 2
    b = tf.reduce_sum(c ** 2 + 2 * c)
    return b, c
a = tf.Variable([[0., 1.], [1., 0.]])
with tf.GradientTape() as tape:
    b, c = f(a)
print('tf.function gradient: ', tape.gradient([b], [c]))
# outputs: tf.function gradient: [None]
with tf.GradientTape() as tape:
    b, c = fplain(a)
print('plain gradient: ', tape.gradient([b], [c]))
# outputs: plain gradient: [<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
# array([[2., 6.],
# [6., 2.]], dtype=float32)>]
The latter (plain) behavior is what I would expect. How can I understand the @tf.function case?
Thank you very much in advance!
(Note that this problem is distinct from: Missing gradient when using tf.function, since here all calculations are inside the function.)
Upvotes: 7
Views: 1389
Reputation: 1641
Gradient tape does not record the individual operations inside the tf.Graph generated by @tf.function; it treats the function as a whole. Roughly, f is applied to a, and the tape records the gradients of the outputs of f with respect to the input a (it is the only watched variable, see tape.watched_variables()). The intermediate tensor c exists only inside the generated graph, which is why tape.gradient([b], [c]) returns None.
In the second case, there is no graph generated, the operations run in eager mode, and every intermediate tensor is recorded, so everything works as expected.
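You can check this yourself (a minimal sketch reusing f and a from the question): the gradient with respect to the watched input a is still available through the graph, while the gradient with respect to the intermediate c is not.

with tf.GradientTape() as tape:
    b, c = f(a)

# a is the watched variable, so its gradient flows through the whole graph
print(tape.gradient([b], [a]))  # tf.Tensor of shape (2, 2)
# c only exists inside the tf.Graph produced by @tf.function
print(tape.gradient([b], [c]))  # None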
A good practice is to wrap the most computationally expensive part of your code in @tf.function
(often a training loop). In your case, it would be something like:
@tf.function
def f(a):
    with tf.GradientTape() as tape:
        c = a * 2
        b = tf.reduce_sum(c ** 2 + 2 * c)
    grads = tape.gradient([b], [c])
    print('tf.function gradient: ', grads)
    return grads
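Calling the wrapped function then runs the forward pass and the gradient computation inside a single graph; grads[0] should match the values from the eager example above. Note that the Python print only runs while the function is being traced; tf.print would print the values on every call.

grads = f(a)  # grads[0] is the gradient of b with respect to c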
Upvotes: 9