Reputation: 5064
I have two Tensors like this:
template = tf.convert_to_tensor([[1, 0, 0.5, 0.5, 0.3, 0.3],
                                 [1, 0, 0.75, 0.5, 0.3, 0.3],
                                 [1, 0, 0.5, 0.75, 0.3, 0.3],
                                 [1, 0, 0.75, 0.75, 0.3, 0.3]])
patch = tf.convert_to_tensor([[0, 1, 0.43, 0.17, 0.4, 0.4],
                              [0, 1, 0.18, 0.22, 0.53, 0.6]])
Now I would like to update the second and the last rows of template with the patch rows to get this value:
[[1. 0. 0.5 0.5 0.3 0.3 ]
[0. 1. 0.43 0.17 0.4 0.4 ]
[1. 0. 0.5 0.75 0.3 0.3 ]
[0. 1. 0.18 0.22 0.53 0.6 ]]
With tf.scatter_update it is easy:
var_template = tf.Variable(template)
var_template = tf.scatter_update(var_template, [1, 3], patch)
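For reference, what tf.scatter_update does here is plain row assignment. A minimal NumPy sketch (NumPy is used only so the target values are easy to check; it is not part of the TensorFlow solution) reproduces the matrix above:

```python
import numpy as np

template = np.array([[1, 0, 0.5, 0.5, 0.3, 0.3],
                     [1, 0, 0.75, 0.5, 0.3, 0.3],
                     [1, 0, 0.5, 0.75, 0.3, 0.3],
                     [1, 0, 0.75, 0.75, 0.3, 0.3]])
patch = np.array([[0, 1, 0.43, 0.17, 0.4, 0.4],
                  [0, 1, 0.18, 0.22, 0.53, 0.6]])

# Copy first so the original template is left untouched
expected = template.copy()
# Overwrite rows 1 and 3 with the two patch rows, like tf.scatter_update
expected[[1, 3]] = patch
print(expected)
```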
However, it requires creating a variable. Is there a way to obtain the value using only tensor operations?
I was thinking about tf.where, but then I would probably have to broadcast every patch row to the template size and call tf.where once for each row.
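The tf.where idea does not necessarily need one call per row: a single boolean row mask, reshaped to [rows, 1], broadcasts across all columns. The sketch below shows that logic in NumPy only; each step has a TensorFlow counterpart (tf.equal, tf.reduce_any, tf.argmax after casting to an integer type, tf.gather, tf.where), but that translation is an assumption and is not shown here.

```python
import numpy as np

template = np.array([[1, 0, 0.5, 0.5, 0.3, 0.3],
                     [1, 0, 0.75, 0.5, 0.3, 0.3],
                     [1, 0, 0.5, 0.75, 0.3, 0.3],
                     [1, 0, 0.75, 0.75, 0.3, 0.3]])
patch = np.array([[0, 1, 0.43, 0.17, 0.4, 0.4],
                  [0, 1, 0.18, 0.22, 0.53, 0.6]])
indices = np.array([1, 3])

rows = np.arange(template.shape[0])        # [0, 1, 2, 3]
# hits[r, k] is True when row r is the k-th row to update
hits = rows[:, None] == indices[None, :]   # shape [4, 2]
row_mask = hits.any(axis=1)                # [False, True, False, True]
# For updated rows argmax finds the matching patch row; for the other rows
# it returns 0, which is harmless because the mask discards those rows
patch_pos = hits.argmax(axis=1)
patch_full = patch[patch_pos]              # one patch row per template row
# One where call; the [4, 1] mask broadcasts across all 6 columns
result = np.where(row_mask[:, None], patch_full, template)
print(result)
```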
Upvotes: 1
Views: 1707
Reputation: 5064
I will also add my own solution here. This utility function works much like tf.scatter_update, but operates on plain tensors instead of Variables:
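def scatter_update_tensor(x, indices, updates):
    '''
    Utility function similar to `tf.scatter_update`, but operating on a Tensor.
    Unlike tf.scatter_update, row indices must have shape [N, 1], e.g. [[1], [3]].
    '''
    x_shape = tf.shape(x)
    # Dense tensor holding the updates at the target rows, zeros elsewhere
    patch = tf.scatter_nd(indices, updates, x_shape)
    # Boolean mask marking the positions that received an update
    mask = tf.greater(tf.scatter_nd(indices, tf.ones_like(updates), x_shape), 0)
    # Patched values where the mask is set, original values everywhere else
    return tf.where(mask, patch, x)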
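tf.scatter_nd starts from a zero tensor of the given shape and adds each update at its index; duplicate indices are summed rather than overwritten, which differs from tf.scatter_update. The hypothetical helpers scatter_nd_np and scatter_update_np below are a NumPy sketch of the same scatter, mask, select steps (row updates only), which makes it easier to see why the indices have to be [[1], [3]] rather than [1, 3]: each index is a length-1 vector naming a row.

```python
import numpy as np

def scatter_nd_np(indices, updates, shape):
    """Hypothetical NumPy stand-in for tf.scatter_nd, restricted to row updates."""
    indices = np.asarray(indices)            # expected shape [N, 1]
    updates = np.asarray(updates)
    out = np.zeros(shape, dtype=updates.dtype)
    # np.add.at accumulates instead of overwriting, so a repeated row index
    # ends up holding the sum of its updates, as with tf.scatter_nd
    np.add.at(out, indices[:, 0], updates)
    return out

def scatter_update_np(x, indices, updates):
    """Same three steps as scatter_update_tensor: scatter, mask, select."""
    patch = scatter_nd_np(indices, updates, x.shape)
    mask = scatter_nd_np(indices, np.ones_like(updates), x.shape) > 0
    return np.where(mask, patch, x)

x = np.zeros((4, 3))
updates = np.array([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]])
print(scatter_update_np(x, [[1], [3]], updates))
```

Because the mask is built from ones rather than from the update values, rows whose new values happen to be zero are still replaced correctly.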
Upvotes: 0
Reputation: 7103
This one should work. It is a bit convoluted, but it uses no variables.
import tensorflow as tf

template = tf.convert_to_tensor([[1, 1, 0.5, 0.5, 0.3, 0.3],
                                 [2, 2, 0.75, 0.5, 0.3, 0.3],
                                 [3, 3, 0.5, 0.75, 0.3, 0.3],
                                 [4, 4, 0.75, 0.75, 0.3, 0.3]])
patch = tf.convert_to_tensor([[1, 1, 1, 0.17, 0.4, 0.4],
                              [3, 3, 3, 0.22, 0.53, 0.6]])
ind = tf.constant([1, 3])
rn_t = tf.range(0, template.shape[0])

def index1d(t, val):
    # 0 if val occurs in t; otherwise a reduce_min over an empty tensor (a huge value)
    return tf.reduce_min(tf.where(tf.equal([t], val)))

def index1dd(t, val):
    # Position of val in t (0 if val is not present)
    return tf.argmax(tf.cast(tf.equal(t, val), tf.int64), axis=0)

# For each row index x: take the matching patch row if x is in ind, else keep template[x]
r = tf.map_fn(lambda x: tf.where(tf.equal(index1d(ind, x), 0), patch[index1dd(ind, x)], template[x]), rn_t, dtype=tf.float32)

with tf.Session() as sess:
    print(sess.run([r]))
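The twisted part is the two helpers. index1d only tests membership: tf.where(tf.equal([t], val)) returns the coordinates of the matches, so their minimum is 0 when val is in t, while for no match the reduction runs over an empty tensor, which as far as I can tell yields the largest int64 value. index1dd then finds which patch row to use. A NumPy sketch of both (the _np names are hypothetical, and initial= is used to mimic that empty-reduction result):

```python
import numpy as np

INT64_MAX = np.iinfo(np.int64).max

def index1d_np(t, val):
    # Coordinates of matches in [t], like tf.where(tf.equal([t], val))
    coords = np.argwhere(np.equal([t], val))
    # With no match, coords is empty; initial= plays the role of the
    # value tf.reduce_min returns for an empty int64 tensor
    return coords.min(initial=INT64_MAX)

def index1dd_np(t, val):
    # Position of the first match in t (0 if there is no match)
    return np.argmax(np.equal(t, val).astype(np.int64), axis=0)

ind = np.array([1, 3])
for x in range(4):
    # Row x is replaced when index1d is 0, using patch[index1dd]
    print(x, index1d_np(ind, x) == 0, index1dd_np(ind, x))
```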
Upvotes: 3