adng

Reputation: 53

How to use GradientTape with AutoGraph in Tensorflow 2?

I cannot figure out how to run GradientTape code with AutoGraph (`tf.function`) in TensorFlow 2.

I want to run GradientTape code on a TPU, and I wanted to start by testing it on CPU. TPU code would run much faster using AutoGraph. I tried watching the input variable, and I tried passing the input as an argument to the function that contains the GradientTape, but both attempts failed.
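For reference, here is a minimal sketch of the usual pattern for gradient code inside `tf.function`. It assumes you differentiate with respect to a trainable `tf.Variable`, which `GradientTape` watches automatically, rather than a constant. The function name `grad_of_square` is illustrative only. The gradient is returned instead of printed. Placing the computation on a TPU (for example through a distribution strategy) is not shown.

```python
import tensorflow as tf

# Trainable tf.Variable: GradientTape watches it automatically, so g.watch() is not needed.
w = tf.Variable(3.0)

@tf.function
def grad_of_square():  # hypothetical name, for illustration
    with tf.GradientTape() as g:
        y = w * w
    # Return the gradient rather than printing it: inside tf.function it is a symbolic tensor.
    return g.gradient(y, w)

dy_dw = grad_of_square()  # calling the function executes the graph and yields a concrete tensor
print(dy_dw.numpy())      # d(w*w)/dw = 2*w = 6.0
```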

I made a reproducible example here: https://colab.research.google.com/drive/1luCk7t5SOcDHQC6YTzJzpSixIqHy_76b#scrollTo=9OQUYQtTTYIt

The code and the corresponding outputs are below. Each cell starts with `import tensorflow as tf`.

```python
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as g:
    g.watch(x)
    y = x * x
dy_dx = g.gradient(y, x)
print(dy_dx)
```

Output: `tf.Tensor(6.0, shape=(), dtype=float32)`

Explanation: With eager execution, GradientTape produces the gradient.

```python
import tensorflow as tf

@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x)  # Will compute to 6.0
    print(dy_dx)

compute_me()
```

Output: `Tensor("AddN:0", shape=(), dtype=float32)`

Explanation: Using AutoGraph on GradientTape in TF2 results in an empty gradient.

```python
import tensorflow as tf

@tf.function
def compute_me_args(x):
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x)  # Will compute to 6.0
    print(dy_dx)

x = tf.constant(3.0)
compute_me_args(x)
```

Output: `Tensor("AddN:0", shape=(), dtype=float32)`

Explanation: Passing the tensor in as an argument also fails.

I was expecting all of the cells to output `tf.Tensor(6.0, shape=(), dtype=float32)`, but instead the cells using AutoGraph output `Tensor("AddN:0", shape=(), dtype=float32)`.

Upvotes: 2

Views: 611

Answers (1)

xdurch0

Reputation: 10474

It doesn't "fail". Python's `print`, when used inside a `tf.function` (i.e. in graph mode), runs while the function is being traced, so it prints the symbolic tensors, and these do not have a value yet. Try this instead:

```python
import tensorflow as tf

@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x)  # Will compute to 6.0
    tf.print(dy_dx)

compute_me()
```

This should print `6`. All you need to do is use `tf.print` instead, which is "smart" enough to print the actual values when they are available. Alternatively, return the value from the function:

```python
import tensorflow as tf

@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x)  # Will compute to 6.0
    return dy_dx

result = compute_me()
print(result)
```

This outputs something like `<tf.Tensor: id=43, shape=(), dtype=float32, numpy=6.0>`. The value (6.0) is visible here as well. Use `result.numpy()` to get just `6.0`.
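To see why plain `print` shows a symbolic tensor, here is a minimal sketch that calls the same `tf.function` twice. The names `compute_me_traced` and `trace_count` are illustrative only. The Python-side `print` and the list append run only while the function is traced, which happens once here because the call signature does not change. `tf.print` is part of the graph and runs on every call.

```python
import tensorflow as tf

trace_count = []  # hypothetical Python-side counter; appended to only during tracing

@tf.function
def compute_me_traced():  # hypothetical name, for illustration
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x)
    trace_count.append(1)          # Python side effect: executes at trace time only
    print("trace-time:", dy_dx)    # Python print: shows the symbolic Tensor, once
    tf.print("run-time:", dy_dx)   # graph op: prints the concrete value on every call
    return dy_dx

first = compute_me_traced()   # traces the function, then runs the graph
second = compute_me_traced()  # reuses the traced graph; the Python print does not run again
print(len(trace_count), first.numpy(), second.numpy())
```

This is the same behavior that makes the question's cells print `Tensor("AddN:0", ...)`: the Python `print` fires once during tracing, before any value exists.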

Upvotes: 1
