I am trying to translate a NumPy sliced-update operation into TensorFlow. I want to reproduce the following minimal example:
```python
>>> import numpy as np
>>> input = np.arange(3 * 5).reshape((3, 5))
>>> input
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14]])
>>> input[:, [0, 2]] = -1
>>> input
array([[-1,  1, -1,  3,  4],
       [-1,  6, -1,  8,  9],
       [-1, 11, -1, 13, 14]])
```
In other words, I want to set all elements of certain columns of the array to a constant value.
Now I have tensors instead of NumPy arrays, and the column indices are also computed dynamically and stored in tensors. I have found how to update all the values in given rows using `tf.scatter_nd_update`:
```python
import tensorflow as tf

input = tf.Variable(tf.reshape(tf.range(3 * 5, dtype=tf.int32), [3, 5]))
indices = tf.constant([[0], [2]])
# Each row of `updates` overwrites the whole row selected by `indices`
updates = tf.constant([[-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1]])
scatter = tf.scatter_nd_update(input, indices, updates)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(scatter))
```
Output:
```
[[-1 -1 -1 -1 -1]
 [ 5  6  7  8  9]
 [-1 -1 -1 -1 -1]]
```
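Why this updates whole rows: according to the `tf.scatter_nd_update` documentation, the last dimension of `indices` is an index depth, and an index shorter than the rank of the variable selects a whole slice. A rough NumPy emulation of that rule is sketched below; `scatter_nd_update_np` is a hypothetical helper (not part of TensorFlow) and it returns a modified copy instead of assigning to a variable.

```python
import numpy as np

def scatter_nd_update_np(ref, indices, updates):
    """Hypothetical NumPy emulation of the tf.scatter_nd_update indexing rule.

    indices.shape[-1] is the index depth d; each index of depth d < ref.ndim
    selects a whole slice of shape ref.shape[d:]. Returns a copy of `ref`.
    """
    out = ref.copy()
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    depth = indices.shape[-1]
    flat_idx = indices.reshape(-1, depth)
    # One update slice of shape ref.shape[depth:] per index
    flat_upd = updates.reshape((-1,) + ref.shape[depth:])
    for idx, upd in zip(flat_idx, flat_upd):
        out[tuple(idx)] = upd
    return out

ref = np.arange(3 * 5).reshape(3, 5)

# Depth-1 indices: each [r] addresses the whole row r
rows = scatter_nd_update_np(ref, [[0], [2]], np.full((2, 5), -1))

# Depth-2 indices: each [r, c] addresses a single element
elems = scatter_nd_update_np(ref, [[0, 0], [1, 0], [2, 0]], [-1, -1, -1])

print(rows)
print(elems)
```

So under this rule, updating a column means supplying one depth-2 `[row, column]` index per element.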
But how can I do this for certain columns?
You can do that like this:
```python
import tensorflow as tf

def update_columns(variable, columns, value):
    columns = tf.convert_to_tensor(columns)
    # Row indices 0..num_rows-1, with the same dtype as the column indices
    rows = tf.range(tf.shape(variable)[0], dtype=columns.dtype)
    # ii[r, k] == r and jj[r, k] == columns[k]
    ii, jj = tf.meshgrid(rows, columns, indexing='ij')
    # One update value per (row, column) pair
    value = tf.broadcast_to(value, tf.shape(ii))
    # Stack into (num_rows, len(columns), 2) index pairs and scatter
    return tf.scatter_nd_update(variable, tf.stack([ii, jj], axis=-1), value)

inp = tf.Variable(tf.reshape(tf.range(3 * 5, dtype=tf.int32), [3, 5]))
updated = update_columns(inp, [0, 2], -1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(updated))
```
Output:
```
[[-1  1 -1  3  4]
 [-1  6 -1  8  9]
 [-1 11 -1 13 14]]
```
Note, however, that you should only use `tf.scatter_nd_update` if you really want to work with a variable (and assign it a new value). If you just want a tensor that is equal to another tensor but with some values updated, use regular tensor operations instead of converting the tensor into a variable. For this case, for example, you could do:
```python
import tensorflow as tf

def update_columns_tensor(tensor, columns, value):
    columns = tf.convert_to_tensor(columns)
    shape = tf.shape(tensor)
    num_rows, num_columns = shape[0], shape[1]
    # (len(columns), num_columns) boolean matrix: mask[k, c] is True where c == columns[k]
    mask = tf.equal(tf.range(num_columns, dtype=columns.dtype), tf.expand_dims(columns, 1))
    # Collapse to one flag per column, then repeat that row of flags for every row
    mask = tf.tile(tf.expand_dims(tf.reduce_any(mask, axis=0), 0), (num_rows, 1))
    value = tf.broadcast_to(value, shape)
    # Take `value` where the mask is set and the original tensor elsewhere
    return tf.where(mask, value, tensor)

inp = tf.reshape(tf.range(3 * 5, dtype=tf.int32), [3, 5])
updated = update_columns_tensor(inp, [0, 2], -1)

with tf.Session() as sess:
    print(sess.run(updated))  # Same output as above
```
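For comparison, a NumPy sketch of the same mask-based idea is below; `update_columns_np` is a hypothetical name. Because `np.where` broadcasts its arguments, a `(1, num_columns)` mask is enough and no tiling is needed. As far as I know, the TF 1.x `tf.where` used above requires `x` and `y` to have the same shape and the condition to match them (or be a 1-D vector over the first dimension), which is presumably why the TensorFlow version tiles the mask and broadcasts `value` explicitly. Like the tensor version, this returns a new array and leaves the input untouched.

```python
import numpy as np

def update_columns_np(array, columns, value):
    """Hypothetical NumPy analogue of update_columns_tensor; returns a new array."""
    columns = np.asarray(columns)
    num_columns = array.shape[1]
    # (len(columns), num_columns) comparison, reduced to one flag per column
    col_mask = np.any(np.arange(num_columns) == columns[:, None], axis=0)
    # The (1, num_columns) mask broadcasts over all rows, replacing tf.tile
    return np.where(col_mask[None, :], value, array)

inp = np.arange(3 * 5).reshape(3, 5)
updated = update_columns_np(inp, [0, 2], -1)
print(updated)
print(inp[0])  # the original array is unchanged
```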