Reputation: 53
I cannot figure out how to run GradientTape code under AutoGraph in TensorFlow 2.
I want to run GradientTape code on a TPU, and I wanted to start by testing it on CPU first. TPU code would run much faster using AutoGraph. I tried watching the input variable, and I also tried passing the input as an argument into the function that contains the GradientTape, but both attempts failed.
I made a reproducible example here: https://colab.research.google.com/drive/1luCk7t5SOcDHQC6YTzJzpSixIqHy_76b#scrollTo=9OQUYQtTTYIt
The code and the corresponding output is as follows:
Each snippet starts with import tensorflow as tf.
x = tf.constant(3.0)
with tf.GradientTape() as g:
    g.watch(x)
    y = x * x
dy_dx = g.gradient(y, x)
print(dy_dx)
Output: tf.Tensor(6.0, shape=(), dtype=float32)
Explanation: using eager execution, the GradientTape produces the gradient.
@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x)  # Will compute to 6.0
    print(dy_dx)

compute_me()
Output: Tensor("AddN:0", shape=(), dtype=float32)
Explanation: using AutoGraph on GradientTape in TF2 seems to result in an empty gradient.
@tf.function
def compute_me_args(x):
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x)  # Will compute to 6.0
    print(dy_dx)

x = tf.constant(3.0)
compute_me_args(x)
Output: Tensor("AddN:0", shape=(), dtype=float32)
Explanation: passing the tensor in as an argument also fails.
I was expecting all of the cells to output tf.Tensor(6.0, shape=(), dtype=float32), but instead the cells using AutoGraph output Tensor("AddN:0", shape=(), dtype=float32).
Upvotes: 2
Views: 611
Reputation: 10474
It doesn't "fail": print, when used inside a tf.function (i.e. in graph mode), prints the symbolic tensors, and these do not carry a concrete value. Try this instead:
@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x)  # Will compute to 6.0
    tf.print(dy_dx)

compute_me()
This should print 6. All you need to do is use tf.print instead, which is "smart" enough to print the actual values when they are available. Or, use a return value:
@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    return g.gradient(y, x)  # Will compute to 6.0

result = compute_me()
print(result)
This outputs something like <tf.Tensor: id=43, shape=(), dtype=float32, numpy=6.0>. You can see that the value (6.0) is visible here as well. Use result.numpy() to get just 6.0.
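Putting both points together, a minimal sketch (the function name compute_grad is my own, not from the question) that returns the gradient out of the tf.function and then converts it to a plain Python float on the eager side:

```python
import tensorflow as tf

@tf.function
def compute_grad():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)       # x is a constant, so it must be watched explicitly
        y = x * x
    return g.gradient(y, x)  # return the tensor instead of printing it

result = compute_grad()      # eager context: result now holds a concrete value
print(result.numpy())        # 6.0 as a plain NumPy scalar
```

Returning the tensor is usually preferable to tf.print in real code, since the caller decides whether to print, log, or feed the value into further computation.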
Upvotes: 1