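I want TensorFlow to do the following in `f(...)`: assign `a` to `cache`, then return `a`, so that evaluating the returned value also updates `cache`. But `tf.control_dependencies` doesn't do what I want. How do I fix the control dependency?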
Result:

```
cache_ 0.0
x_ 2.0
AssertionError
```
Test:
```python
import tensorflow as tf
import numpy as np

def f(a, cache):
    assign_op = tf.assign(cache, a)
    with tf.control_dependencies([assign_op]):
        return a

def main():
    dtype = np.float32
    data = tf.range(5, dtype=dtype)
    cache = tf.Variable(0, dtype=dtype)
    x = f(data[2], cache)
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        x_ = sess.run(x)
        cache_ = sess.run(cache)
        print("cache_", cache_)
        print("x_", x_)
        assert np.allclose(cache_, x_)

main()
```
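The problem is that `return a` is plain Python code: it does not create any TensorFlow op inside the `with` block. `tf.control_dependencies` only affects ops that are created inside the block, and `a` was created earlier, outside it, so nothing ties `assign_op` to `a`. You can use `tf.identity` to create a new op inside the block; whenever that op is evaluated, `assign_op` will be executed first. Here is the updated code: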
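```python
import tensorflow as tf

def f(a, cache):
    assign_op = tf.assign(cache, a)
    with tf.control_dependencies([assign_op]):
        # tf.identity creates a new op inside the block, so it picks up the
        # control dependency: evaluating it runs assign_op first.
        return tf.identity(a)
```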