Reputation: 574
I want to set a single element of a 2D tensor to 0. Here, data is a 2D tensor, and the value in its 2nd row, 2nd column should be replaced by 0. However, I am getting a type error. Can anyone help me with it?
TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input
data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
data2 = tf.reshape(data, [-1])
sparse_update = tf.scatter_update(data2, tf.constant([7]), tf.constant([0]))
#data = tf.reshape(data, [N,S])
init_op = tf.initialize_all_variables()
sess = tf.Session()
sess.run([init_op])
print "Values before:", sess.run([data])
#sess.run([updated_data_subset])
print "Values after:", sess.run([sparse_update])
Upvotes: 4
Views: 7274
Reputation: 606
This is the function I use to modify a subset (row/column) of a 2D tensor in TensorFlow 2:
# Note: if isVector is set, updatedValue should be provided in 2D format
def modifyTensorRowColumn(a, isRow, index, updatedValue, isVector):
    if(not isRow):
        a = tf.transpose(a)
        if(isVector):
            updatedValue = tf.transpose(updatedValue)
    if(index == 0):
        if(isVector):
            values = [updatedValue, a[index+1:]]
        else:
            values = [[updatedValue], a[index+1:]]
    elif(index == a.shape[0]-1):
        if(isVector):
            values = [a[:index], updatedValue]
        else:
            values = [a[:index], [updatedValue]]
    else:
        if(isVector):
            values = [a[:index], updatedValue, a[index+1:]]
        else:
            values = [a[:index], [updatedValue], a[index+1:]]
    a = tf.concat(axis=0, values=values)
    if(not isRow):
        a = tf.transpose(a)
    return a
Upvotes: 0
Reputation: 57893
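tf.scatter_update only works on variables, not on tensors produced by other ops. Instead, rebuild the tensor with this pattern, where i is the index to replace and updated_value is the new value:
TensorFlow version < 1.0:
a = tf.concat(0, [a[:i], [updated_value], a[i+1:]])
TensorFlow version >= 1.0:
a = tf.concat(axis=0, values=[a[:i], [updated_value], a[i+1:]])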
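To see what this pattern does without running TensorFlow, here is a minimal sketch using plain Python lists, where list concatenation with + plays the role of tf.concat along axis 0. The helper name replace_at is hypothetical and used only for illustration. For a 2D value, the pattern is applied first to the row and then to the list of rows.

```python
def replace_at(seq, i, updated_value):
    """Return a new list equal to seq with position i replaced (seq is not modified)."""
    return seq[:i] + [updated_value] + seq[i + 1:]


data = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 0], [1, 2, 3, 4, 5]]

# Replace the element in row 1, column 1 (0-based):
# rebuild the row first, then rebuild the outer list with the new row.
new_row = replace_at(data[1], 1, 0)
updated = replace_at(data, 1, new_row)

print(updated)
```

As with the tf.concat version, nothing is modified in place: a new value is built from the pieces before and after the replaced position.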
Upvotes: 11
Reputation: 2026
tf.scatter_update can only be applied to the Variable type. In your code, data is a Variable, while data2 is not, because tf.reshape returns a Tensor.
Solution:
For TensorFlow 1.0 and later:
data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat([row[:2], tf.constant([0]), row[3:]], axis=0)
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)
For TensorFlow before 1.0:
data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat(0, [row[:2], tf.constant([0]), row[3:]])
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)
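Note that the snippets above replace the element at row 2, column 2 (0-based), while the question's code targets flat index 7 of the reshaped tensor. A small plain-Python sketch, independent of TensorFlow, of how the two kinds of index relate under the row-major layout that tf.reshape uses. The function names are hypothetical and only for illustration.

```python
def flat_to_row_col(flat_index, num_cols):
    """Map an index into the row-major flattened tensor to (row, col)."""
    return divmod(flat_index, num_cols)


def row_col_to_flat(row, col, num_cols):
    """Map (row, col) to the index into the row-major flattened tensor."""
    return row * num_cols + col


NUM_COLS = 5  # data in the question has shape [3, 5]

# Which element does flat index 7 address?
print(flat_to_row_col(7, NUM_COLS))

# Which flat index corresponds to row 2, column 2?
print(row_col_to_flat(2, 2, NUM_COLS))
```

Under these assumptions, adjust the row passed to tf.gather and the slice positions in new_row to match whichever element you actually want to replace.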
Upvotes: 3