I am looking for a TF operation that repeats the i-th row of an input tensor x exactly y[i] times, where y is a second tensor holding one count per row. More precisely, the operation should do the following:
x = tf.constant([[1, 4], [2, 5], [3, 6]])
y = tf.constant([3, 2, 4])
z = <operation>(x, y)  # [[1, 4], [1, 4], [1, 4],
                       #  [2, 5], [2, 5],
                       #  [3, 6], [3, 6], [3, 6], [3, 6]]
What operation can I use? Thanks :)
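To pin down the requested semantics, here is a minimal plain-Python reference (no TensorFlow), assuming x is a list of rows and y a list of non-negative counts of the same length. The name repeat_rows is hypothetical and used only for illustration. It builds a flat index list with each row index i repeated y[i] times, then picks the rows at those indices.

```python
def repeat_rows(x, y):
    """Hypothetical reference helper: row x[i] is emitted y[i] times, in order."""
    # Step 1: flat index list, e.g. y = [3, 2, 4] -> [0, 0, 0, 1, 1, 2, 2, 2, 2].
    indices = [i for i, count in enumerate(y) for _ in range(count)]
    # Step 2: "gather" the rows at those indices.
    return [x[i] for i in indices]


x = [[1, 4], [2, 5], [3, 6]]
y = [3, 2, 4]
z = repeat_rows(x, y)
print(z)
```

The same two steps (build the index vector, then gather) map directly onto TensorFlow ops.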
The key idea is to build a 1-D tensor of indices in which each row index i appears y[i] times, and then use tf.gather to select those rows:
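import tensorflow as tf

def repeat(t, times):
    # Number of rows in t (one count per row in times).
    num_elements = tf.shape(t)[0]

    def cond_fn(i, _):
        return i < num_elements

    def body_fn(i, indices_ta):
        # Row index i tiled times[i] times, e.g. [i, i, i].
        repeated_i = tf.tile(i[tf.newaxis], times[i, tf.newaxis])
        return (i + 1, indices_ta.write(i, repeated_i))

    # Entries have different lengths, so shape inference is disabled.
    indices_ta = tf.TensorArray(times.dtype, num_elements, infer_shape=False)
    _, indices_ta = tf.while_loop(
        cond_fn,
        body_fn,
        # Start from a tensor, not a Python int, so i[tf.newaxis] also works eagerly.
        loop_vars=(tf.constant(0), indices_ta))
    # Concatenate the per-row index vectors and gather the matching rows.
    return tf.gather(t, indices_ta.concat())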
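In recent TensorFlow 2.x releases the built-in tf.repeat covers this case directly, and it can also build the index vector for the gather without an explicit while loop. The sketch below assumes TensorFlow 2.x with eager execution and int32 counts. It is an alternative to the loop above, not part of the original answer.

```python
import tensorflow as tf

x = tf.constant([[1, 4], [2, 5], [3, 6]])
y = tf.constant([3, 2, 4])

# Built-in: repeat each slice along axis 0 by the matching count in y.
z = tf.repeat(x, repeats=y, axis=0)

# Same index-then-gather idea, without tf.while_loop:
# repeating tf.range(3) by y gives [0, 0, 0, 1, 1, 2, 2, 2, 2].
indices = tf.repeat(tf.range(tf.shape(x)[0]), repeats=y)
z_gather = tf.gather(x, indices)

print(z.numpy().tolist())
```

Both z and z_gather should equal the expected output from the question. The gather form is handy if you need the indices themselves, for example to repeat several tensors consistently.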