Feri

TensorFlow 2.0 AutoGraph indirect modification (hidden states) works when it shouldn't

So, the AutoGraph docs here say that indirect modification should not work, meaning that such changes would be invisible. (What does an "invisible" change mean, anyway?)

But this code computes the gradient correctly:

import tensorflow as tf


class C:
    def __init__(self):
        self.x = tf.Variable(2.0)

    @tf.function
    def change(self):
        self.x.assign_add(2.0)

    @tf.function
    def func(self):
        self.change()
        return self.x * self.x


c = C()
with tf.GradientTape() as tape:
    y = c.func()
print(tape.gradient(y, c.x)) # --> tf.Tensor(8.0, shape=(), dtype=float32)

Am I missing something here?

Thanks

Answers (1)

Dan Moldovan

The docs are missing a detail and should be clarified: "invisible" means that the change is not detected by AutoGraph's analyzer. Because AutoGraph analyzes one function at a time, modifications made inside another function (such as a helper method you call) are not visible to the analyzer.

However, this caveat does not apply to ops with side effects, such as modifications to TF Variables: those are still wired correctly into the graph, because tf.function adds automatic control dependencies between stateful ops. So your code should work correctly.
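To illustrate the point about stateful ops, here is a minimal sketch. The Counter class and its bump/run methods are hypothetical names, not from the question. The variable is updated inside a helper that AutoGraph does not analyze as part of run, and even inside a tf.range loop the update is still applied, because assign_add is a stateful op on the variable rather than a rebinding of a Python name.

```python
import tensorflow as tf


class Counter:  # hypothetical example class
    def __init__(self):
        # A TF Variable: updates to it are stateful ops, not Python rebinding.
        self.v = tf.Variable(0)

    def bump(self):
        # Called from inside the loop, but defined outside the function
        # AutoGraph is analyzing; the assign_add op is recorded anyway.
        self.v.assign_add(1)

    @tf.function
    def run(self, n):
        for _ in tf.range(n):  # AutoGraph converts this into a tf.while_loop
            self.bump()
        # Automatic control dependencies order this read after the loop.
        return self.v.read_value()


c = Counter()
print(c.run(tf.constant(3)))  # tf.Tensor(3, shape=(), dtype=int32)
```

The same pattern is what makes the code in the question compute the expected gradient: change() only issues an assign_add on self.x, which needs no help from the analyzer.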

The limitation only applies to some changes made to pure Python objects (lists, dicts, etc.), and it is only a problem when those changes happen inside TensorFlow control flow, such as a loop over tf.range.
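The underlying mechanism can be reproduced with a raw tf.while_loop, without relying on AutoGraph's loop conversion. In this sketch (stash_in_dict and holder are hypothetical names), a tensor created inside the loop body is stashed in a plain Python dict instead of being passed through loop_vars. Reading it after the loop should fail when the function is traced; assuming TF 2.x behavior, the exception type is InaccessibleTensorError.

```python
import tensorflow as tf


@tf.function
def stash_in_dict():  # hypothetical example function
    holder = {}  # a pure Python object, not a loop variable

    def cond(i):
        return i < 3

    def body(i):
        # This tensor belongs to the while_loop body's own graph.
        holder["last"] = i * 10
        return i + 1

    tf.while_loop(cond, body, [tf.constant(0)])
    # Using the stashed tensor from the outer graph fails at trace time.
    return holder["last"] * 2


try:
    stash_in_dict()
except Exception as e:
    print(type(e).__name__)  # expected to name the error class
```

The fix in this raw form is the one the error message hints at: return the value from body as an extra loop variable instead of smuggling it out through the dict.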

For example, here's a modification of your code that wouldn't work:

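import tensorflow as tf


class C:
    def __init__(self):
        self.x = None

    def reset(self):
        self.x = tf.constant(10)

    def change(self):
        # Rebinds the Python attribute self.x to a new tensor. AutoGraph does
        # not see this assignment while analyzing func, so self.x is not
        # turned into a loop variable.
        self.x += 1

    @tf.function
    def func(self):
        self.reset()
        for i in tf.range(3):
            self.change()
        return self.x * self.x  # raises InaccessibleTensorError


c = C()
print(c.func())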

The error message is rather obscure, but it's the same error that gets raised if you try to access the result of an op created inside the body of a tf.while_loop without using loop_vars:

    <ipython-input-18-23f1641cfa01>:20 func  *
        return self.x * self.x

    ... more internal frames ...

    InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(), dtype=int32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_685, id=5029696157776); accessed from: FuncGraph(name=func, id=5029690557264).
