Reputation: 141
I found some very strange behavior of tf.scatter_add. I created a tf.while_loop that fills a Tensor wrapped inside a tf.Variable using tf.scatter_add.
If I don't first apply a scatter_add to the Variable outside the loop, TensorFlow raises an error telling me that the Variable is not mutable.
Here is an MWE:
import tensorflow as tf
m = 25
batch_num = 32
num_bus = 50
C = tf.zeros((m, batch_num, num_bus, m),tf.float64)
C = tf.Variable(C)
c = tf.ones((batch_num, num_bus, m), tf.float64)
# uncommenting the next line makes the error go away
#C = tf.scatter_add(C,0,c)
k = tf.constant(1)
stop_cond = lambda k, C: k < m
def construct_C(k, C):
    # add c+1 into slice k of C
    upd_c = c + 1
    C = tf.scatter_add(C, k, upd_c)
    return k + 1, C
k, C = tf.while_loop(stop_cond, construct_C, (k, C))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
C1 = sess.run(C)
This code raises an error:
TypeError: 'ScatterAdd' Op requires that input 'ref' be a mutable tensor (e.g.: a tf.Variable)
However, when I uncomment C = tf.scatter_add(C,0,c), everything works fine.
Is this intended? What am I doing wrong?
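For reference, the result the loop is meant to produce can be written eagerly in NumPy, which gives something to check any working graph-mode version against. This is a sketch, not a fix for the TensorFlow error; it assumes np.add.at behaves like tf.scatter_add along the first axis (an unbuffered in-place add at the given index), and it keeps the commented-out line from the MWE disabled, so slice 0 stays at zero.

```python
import numpy as np

# Same shapes as in the MWE above.
m, batch_num, num_bus = 25, 32, 50

C = np.zeros((m, batch_num, num_bus, m), dtype=np.float64)
c = np.ones((batch_num, num_bus, m), dtype=np.float64)

# Eager equivalent of the while_loop: k starts at 1 and runs while k < m;
# each step scatter-adds (c + 1) into the slice C[k].
k = 1
while k < m:
    upd_c = c + 1
    np.add.at(C, k, upd_c)  # in-place add into C[k], analogous to scatter_add on axis 0
    k += 1

# Expected: C[0] is all zeros, every slice C[1:] is all 2.0
```

If the line C = tf.scatter_add(C,0,c) is active, slice 0 would contain ones instead of zeros.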
Upvotes: 0
Views: 159
Reputation: 899
It sounds like some of the while_loop primitives don't know about Variables (they only know about Tensors of ref type). This looks like a bug in TensorFlow; please file an issue on GitHub.
Upvotes: 1