user3446746

tf.scatter_add causes an error inside tf.while_loop

I found some very strange behavior in tf.scatter_add: I wrapped a Tensor inside a tf.Variable and passed it to a tf.while_loop whose body updates it with tf.scatter_add.

Unless I first apply tf.scatter_add to the Variable once outside the loop, TensorFlow raises an error saying that the Variable is not a mutable tensor.

Here is an MWE:

import tensorflow as tf

m = 25
batch_num = 32
num_bus = 50

# C is the Variable that the loop fills slice by slice
C = tf.zeros((m, batch_num, num_bus, m), tf.float64)
C = tf.Variable(C)

c = tf.ones((batch_num, num_bus, m), tf.float64)
# Uncommenting the next line makes the loop below work
#C = tf.scatter_add(C,0,c)

k = tf.constant(1)

stop_cond = lambda k, C: k < m

def construct_C(k, C):
    upd_c = c + 1
    # Raises the TypeError while the loop graph is being built
    C = tf.scatter_add(C, k, upd_c)
    return k + 1, C

k, C = tf.while_loop(stop_cond, construct_C, (k, C))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
C1 = sess.run(C)

This code raises an error:

TypeError: 'ScatterAdd' Op requires that input 'ref' be a mutable tensor (e.g.: a tf.Variable)

However, when I uncomment C = tf.scatter_add(C,0,c), everything works fine.

Is this intended? What am I doing wrong?
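For reference, here is a NumPy sketch of what the loop is meant to compute, assuming the TF1 semantics of tf.scatter_add with a scalar index (the update is added into slice C[k]). np.add.at stands in for tf.scatter_add here, and the helper build_C_reference is a hypothetical name, not part of TensorFlow.

```python
import numpy as np

m = 25
batch_num = 32
num_bus = 50


def build_C_reference(m, batch_num, num_bus):
    """Hypothetical NumPy stand-in for the tf.while_loop in the MWE."""
    C = np.zeros((m, batch_num, num_bus, m), dtype=np.float64)
    c = np.ones((batch_num, num_bus, m), dtype=np.float64)
    k = 1
    while k < m:  # same condition as stop_cond
        upd_c = c + 1
        # Adds upd_c into C[k] in place, mirroring tf.scatter_add(C, k, upd_c)
        np.add.at(C, k, upd_c)
        k += 1
    return C


C1 = build_C_reference(m, batch_num, num_bus)
# Slice 0 is never touched by the loop; slices 1..m-1 each receive c + 1
print(C1[0].max(), C1[1:].min())
```

With the extra scatter_add outside the loop uncommented, slice 0 would hold 1.0 instead of 0.0.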

Answers (1)

Eugene Brevdo

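It sounds like some of the while_loop primitives don't know about Variables; instead, they know about Tensors of ref type. Passing C directly as a loop variable converts the Variable into an ordinary (non-ref) Tensor holding its value, which tf.scatter_add cannot update. The output of the uncommented tf.scatter_add, by contrast, is already a ref-type Tensor, so it keeps its ref type inside the loop body. This looks like a bug in TensorFlow; please file an issue on GitHub.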
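One possible workaround, sketched here under the assumption of TF1-style graph mode (run through tf.compat.v1 so it also works on TF 2.x), is to keep the Variable out of loop_vars entirely. The body captures C from the enclosing scope, updates it with tf.scatter_add, and uses tf.control_dependencies so the counter only advances after the update. This is a common TF1 pattern, not an official fix for the issue.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode and TF1 control flow, as in the question

m = 25
batch_num = 32
num_bus = 50

C = tf.Variable(tf.zeros((m, batch_num, num_bus, m), tf.float64))
c = tf.ones((batch_num, num_bus, m), tf.float64)


def stop_cond(k):
    return k < m


def construct_C(k):
    # C is captured by closure instead of being a loop variable,
    # so scatter_add still receives the Variable itself
    upd = tf.scatter_add(C, k, c + 1)
    # Tie the counter to the update so that running the loop also runs it
    with tf.control_dependencies([upd]):
        return k + 1


k_final = tf.while_loop(stop_cond, construct_C, [tf.constant(1)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(k_final)  # run the loop once
    C1 = sess.run(C)   # read the updated Variable
```

Because only k goes through loop_vars, while_loop never has to convert the Variable; after the run, C1[0] stays zero and every slice from 1 to m-1 equals c + 1.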