Reputation: 241
I'm following the question "Manipulating matrix elements in tensorflow", which uses tf.scatter_update. But my problem is: what happens if my tf.Variable is 2D? Let's say:
a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]])
How can I update, for example, the first element of every row and assign it the value 1?
I tried something like this:
for line in range(2):
    sess.run(tf.scatter_update(a[line], [0], [1]))
but it fails (as I expected) and gives me this error:
TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input
How can I fix this kind of problem?
Upvotes: 16
Views: 8840
Reputation: 458
I found something that works here. I made a variable named U = [[1, 2, 3], [4, 5, 6]] and wanted to update it as U[:,1] = [2, 3], so I did U[:,1].assign(cast_into_tensor[2,3]).
Here is a simple example:
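import tensorflow as tf
from keras import backend as K  # assumption: K is the Keras backend

x = tf.Variable([[1, 2, 3], [4, 5, 6]])
print(K.eval(x))  # values before the update
y = [0, 0]
# Force the sliced assign on column 1 to run before x is read again.
with tf.control_dependencies([x[:, 1].assign(y)]):
    x = tf.identity(x)
print(K.eval(x))  # values after the update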
Upvotes: 1
Reputation: 48330
In TensorFlow you cannot update a Tensor, but you can update a Variable.
The scatter_update operator can only update whole slices along the first dimension of the variable.
You always have to pass a reference tensor to scatter_update (a instead of a[line]).
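To make the first-dimension restriction concrete, here is a minimal pure-Python model of the row-replacement semantics. It is not TensorFlow code: scatter_update_rows is a hypothetical helper written only for illustration, and plain nested lists stand in for the variable.

```python
def scatter_update_rows(ref, indices, updates):
    """Pure-Python model (hypothetical helper, not a TensorFlow API).

    For each position k, the whole row ref[indices[k]] is replaced by
    updates[k]. Single elements cannot be addressed this way.
    """
    if len(indices) != len(updates):
        raise ValueError("need exactly one update row per index")
    for i, row in zip(indices, updates):
        if len(row) != len(ref[i]):
            raise ValueError("each update must be a full row")
        ref[i] = list(row)
    return ref


a = [[0, 0, 0, 0], [0, 0, 0, 0]]

# To change only column 0, build complete replacement rows first:
# copy every existing row and overwrite its first element with 1.
new_rows = [[1] + row[1:] for row in a]

scatter_update_rows(a, [0, 1], new_rows)
print(a)
```

The point of the sketch is that every update has to be a complete row, which is why the indices select rows and the updates carry all four columns.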
This is how you can update the first element of every row of the variable:
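import tensorflow as tf

g = tf.Graph()
with g.as_default():
    a = tf.Variable(initial_value=[[0, 0, 0, 0], [0, 0, 0, 0]])
    # scatter_update replaces whole rows: rows 0 and 1 get the new values.
    b = tf.scatter_update(a, [0, 1], [[1, 0, 0, 0], [1, 0, 0, 0]])

with tf.Session(graph=g) as sess:
    # Later TF 1.x releases rename this to tf.global_variables_initializer().
    sess.run(tf.initialize_all_variables())
    print(sess.run(a))  # values before the update
    print(sess.run(b))  # values after the update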
Output:
[[0 0 0 0]
 [0 0 0 0]]
[[1 0 0 0]
 [1 0 0 0]]
But since this means writing the whole rows again, it might be faster to just assign a completely new tensor to the variable.
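If only individual elements need to change, tf.scatter_nd_update may be a closer fit, since it takes full (row, column) coordinates instead of row indices. The sketch below is an alternative to the answer's approach, not part of it. It assumes TensorFlow 1.14 or later (including 2.x), where the graph-mode API is importable as tensorflow.compat.v1; on older 1.x releases a plain import tensorflow as tf should serve the same purpose.

```python
# Assumption: TensorFlow >= 1.14 or 2.x, which ships the compat.v1 module.
import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
    a = tf.Variable(initial_value=[[0, 0, 0, 0], [0, 0, 0, 0]])
    # Each index is a full [row, column] coordinate, so only the first
    # element of row 0 and of row 1 is written.
    b = tf.scatter_nd_update(a, [[0, 0], [1, 0]], [1, 1])
    init = tf.global_variables_initializer()

with tf.Session(graph=g) as sess:
    sess.run(init)
    result = sess.run(b)  # runs the update and returns the new value
    print(result)
```

Here the other columns are never touched, so there is no need to rebuild complete rows before the update.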
Upvotes: 14