So, here the docs say that indirect modification should not work, which means the changes would be invisible. (What does an "invisible" change mean, anyway?)
But this code computes the gradient correctly:
import tensorflow as tf

class C:
    def __init__(self):
        self.x = tf.Variable(2.0)

    @tf.function
    def change(self):
        self.x.assign_add(2.0)

    @tf.function
    def func(self):
        self.change()
        return self.x * self.x

c = C()
with tf.GradientTape() as tape:
    y = c.func()
print(tape.gradient(y, c.x))  # --> tf.Tensor(8.0, shape=(), dtype=float32)
Am I missing something here?
Thanks
The docs are missing a detail and should be clarified: "invisible" means the change is not detected by AutoGraph's analyzer. Since AutoGraph analyzes one function at a time, modifications made in another function are not visible to the analyzer.
However, this caveat does not apply to ops with side effects, such as modifications to TF Variables: those are still wired correctly into the graph. So your code should work correctly.
The limitation only applies to some changes made to pure Python objects (lists, dicts, etc.), and it only causes problems inside control flow that AutoGraph converts to graph ops (for example, a `for` loop over `tf.range`).
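By contrast, here is a minimal sketch (assuming TF 2.x with eager execution enabled; the class mirrors the question's code) in which a helper function updates a `tf.Variable` inside a loop over `tf.range`. Because `assign` and `assign_add` are stateful ops, they should still be wired into the graph even though AutoGraph never sees the modification inside `func` itself.

```python
import tensorflow as tf

class C:
    def __init__(self):
        # State lives in a tf.Variable, so updates are stateful ops, not Python rebinding
        self.x = tf.Variable(0)

    def reset(self):
        self.x.assign(10)

    def change(self):
        # Modified in another function, but as a side-effecting op it is still captured
        self.x.assign_add(1)

    @tf.function
    def func(self):
        self.reset()
        for i in tf.range(3):  # converted by AutoGraph into a tf.while_loop
            self.change()
        return self.x * self.x

c = C()
print(c.func())  # tf.Tensor(169, shape=(), dtype=int32)
```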
For example, here's a modification of your code that wouldn't work:
import tensorflow as tf

class C:
    def __init__(self):
        self.x = None

    def reset(self):
        self.x = tf.constant(10)

    def change(self):
        self.x += 1

    @tf.function
    def func(self):
        self.reset()
        for i in tf.range(3):
            self.change()
        return self.x * self.x

c = C()
print(c.func())  # raises InaccessibleTensorError
The error message is rather obscure, but it's the same error that gets raised if you try to access the result of an op created inside the body of a tf.while_loop without using loop_vars:
<ipython-input-18-23f1641cfa01>:20 func *
return self.x * self.x
... more internal frames ...
InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(),
dtype=int32)' cannot be accessed here: it is defined in another function or
code block. Use return values, explicit Python locals or TensorFlow
collections to access it. Defined in: FuncGraph(name=while_body_685,
id=5029696157776); accessed from: FuncGraph(name=func, id=5029690557264).
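One way to avoid the error, following the advice in the message itself ("Use return values, explicit Python locals"), is to have `change` return the new value and keep it in a local variable that is assigned directly in the loop body, so AutoGraph can detect it as a loop variable. This is a sketch assuming TF 2.x; the `x` parameter of `change` is an illustrative refactor, not part of the original code.

```python
import tensorflow as tf

class C:
    def reset(self):
        return tf.constant(10)

    def change(self, x):
        # Illustrative refactor: return the new value instead of mutating self.x
        return x + 1

    @tf.function
    def func(self):
        x = self.reset()
        for i in tf.range(3):
            # x is a local assigned in the loop body, so AutoGraph treats it as a loop variable
            x = self.change(x)
        return x * x

c = C()
print(c.func())  # tf.Tensor(169, shape=(), dtype=int32)
```

This keeps the state that flows through the loop explicit, which is what AutoGraph needs when it builds the underlying `tf.while_loop`.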