Reputation: 2615
How do I change a single value of a Tensor inside of a while loop?
I know that I can manipulate a single value of a tf.Variable using tf.scatter_update(variable, index, value), but inside of a loop I cannot access variables. Is there a way or workaround to manipulate a given value of a Tensor inside of a while loop?
For reference, here is my current code:
my_variable = tf.Variable()
def body(i, my_variable):
    [...]
    return tf.add(i, 1), tf.scatter_update(my_variable, [index], value)
loop = tf.while_loop(lambda i, _: tf.less(i, 5), body, [0, my_variable])
Upvotes: 3
Views: 488
Reputation: 2860
Inspired by this post, you could use a sparse tensor to store the delta between the current value and the value you want to assign, and then use addition to "set" that value. For example (I'm assuming some shapes and values here, but it should be straightforward to generalize to tensors of higher rank):
import tensorflow as tf

my_variable = tf.Variable(tf.ones([5]))

def body(i, v):
    index = i
    new_value = 3.0
    # Difference between the desired value and the current one (shape [1]).
    delta_value = new_value - v[index:index+1]
    # Sparse delta: zero everywhere except at `index`.
    delta = tf.SparseTensor([[index]], delta_value, (5,))
    # Adding the dense delta "sets" v[index] to new_value.
    v_updated = v + tf.sparse_tensor_to_dense(delta)
    return tf.add(i, 1), v_updated

_, updated = tf.while_loop(lambda i, _: tf.less(i, 5), body, [0, my_variable])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(my_variable))
    print(sess.run(updated))
This prints the original variable followed by the updated tensor:
[1. 1. 1. 1. 1.]
[3. 3. 3. 3. 3.]
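To see why adding a delta "sets" the value, here is a minimal plain-Python sketch of the same arithmetic, without TensorFlow. The helper name set_by_delta is hypothetical and only for illustration; the list `delta` plays the role of the densified sparse tensor.

```python
def set_by_delta(values, index, new_value):
    """Return a copy of `values` with values[index] replaced, using only addition."""
    # Dense form of the sparse delta: zero everywhere except at `index`.
    delta = [0.0] * len(values)
    delta[index] = new_value - values[index]
    # Element-wise addition: untouched entries get + 0.0,
    # and the target entry becomes values[index] + (new_value - values[index]).
    return [v + d for v, d in zip(values, delta)]


v = [1.0] * 5
i = 0
while i < 5:  # mirrors the loop condition tf.less(i, 5)
    v = set_by_delta(v, i, 3.0)
    i += 1
print(v)
```

Because the assignment goes through a subtraction and an addition, the result for arbitrary floating-point values can differ from new_value by rounding error; with the values used here it is exact.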
Upvotes: 3