I have an input of shape (BATCH_SIZE, A, B, FEATURE_LENGTH). Now I want to select k (out of B) rows from each of the A blocks of each input sample. The k row indices are different for each of the A blocks.
For example:
inp = ([[[[ 5, 38, 40, 13, 28],
          [12,  6, 36, 20, 23],
          [44, 35, 23, 46,  3]],
         [[22, 32, 36, 20, 42],
          [ 0, 19, 41, 36, 17],
          [ 9, 35, 44,  7, 19]],
         [[27, 10, 22, 10, 48],
          [16, 42, 27,  7, 38],
          [35, 32, 15, 39, 28]]]])
# shape (1, 3, 3, 5) = (1, A, B, FEATURE_LENGTH)
Now say k=2, i.e. I want to extract 2 rows from each of the 3 blocks. I want:
rows 0 and 1 from the 1st block
rows 1 and 2 from the 2nd block
rows 0 and 2 from the 3rd block
That means I want my output to look like this:
([[[[ 5, 38, 40, 13, 28],
    [12,  6, 36, 20, 23]],
   [[ 0, 19, 41, 36, 17],
    [ 9, 35, 44,  7, 19]],
   [[27, 10, 22, 10, 48],
    [35, 32, 15, 39, 28]]]])
# output shape = (1, 3, 2, 5)
I found that this is possible with tf.gather_nd if I provide the indices as
ind = array([[[[0, 0, 0], [0, 0, 1]],
              [[0, 1, 1], [0, 1, 2]],
              [[0, 2, 0], [0, 2, 2]]]])
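For reference, a minimal runnable sketch of this tf.gather_nd approach on the example above, assuming TensorFlow 2.x and writing the input as a tf.constant. Each innermost triple in ind is a complete (sample, block, row) coordinate, so the index tensor has shape (1, 3, 2, 3); gather_nd consumes that last axis of length 3 and returns one FEATURE_LENGTH row per coordinate, giving shape (1, 3, 2, 5).

```python
import tensorflow as tf

# Example input from the question: shape (1, A=3, B=3, FEATURE_LENGTH=5)
inp = tf.constant([[[[ 5, 38, 40, 13, 28],
                     [12,  6, 36, 20, 23],
                     [44, 35, 23, 46,  3]],
                    [[22, 32, 36, 20, 42],
                     [ 0, 19, 41, 36, 17],
                     [ 9, 35, 44,  7, 19]],
                    [[27, 10, 22, 10, 48],
                     [16, 42, 27,  7, 38],
                     [35, 32, 15, 39, 28]]]])

# Every innermost triple is a full (sample, block, row) coordinate: shape (1, 3, 2, 3)
ind = tf.constant([[[[0, 0, 0], [0, 0, 1]],
                    [[0, 1, 1], [0, 1, 2]],
                    [[0, 2, 0], [0, 2, 2]]]])

# The last index axis (length 3) is consumed, so each coordinate yields one row of length 5
out = tf.gather_nd(inp, ind)  # shape (1, 3, 2, 5)
```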
But if my input has shape (1, 16, 16, 128) and k=4, writing out this long index array by hand gets tedious.
Is there a simpler way to do this in TensorFlow 2?
Thank you!
Use tf.gather() with the batch_dims argument:
inds = tf.constant([[[0, 1], [1, 2], [0, 2]]])
output = tf.gather(inp, inds, batch_dims=2)
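A runnable sketch of this answer on the question's example, assuming TensorFlow 2.x. With batch_dims=2 the first two axes (sample and block) are treated as batch axes, and because axis defaults to batch_dims, each block's k indices pick rows along axis 2 of that block only. The output shape is inp.shape[:2] + inds.shape[2:] + inp.shape[3:] = (1, 3, 2, 5). The larger (1, 16, 16, 128) case with k=4 works the same way; how the row indices are chosen there (4 distinct random rows per block via tf.argsort on random keys) is only an illustrative assumption, since the question does not say.

```python
import tensorflow as tf

# Example input from the question: shape (BATCH_SIZE, A, B, FEATURE_LENGTH) = (1, 3, 3, 5)
inp = tf.constant([[[[ 5, 38, 40, 13, 28],
                     [12,  6, 36, 20, 23],
                     [44, 35, 23, 46,  3]],
                    [[22, 32, 36, 20, 42],
                     [ 0, 19, 41, 36, 17],
                     [ 9, 35, 44,  7, 19]],
                    [[27, 10, 22, 10, 48],
                     [16, 42, 27,  7, 38],
                     [35, 32, 15, 39, 28]]]])

# k row indices per block: shape (BATCH_SIZE, A, k) = (1, 3, 2)
inds = tf.constant([[[0, 1], [1, 2], [0, 2]]])

# Axes 0 and 1 are batch axes; rows are gathered along axis 2 (the default axis = batch_dims)
output = tf.gather(inp, inds, batch_dims=2)  # shape (1, 3, 2, 5)

# Larger case: input (1, 16, 16, 128), k = 4
big = tf.random.uniform((1, 16, 16, 128))
# Illustrative assumption: pick 4 distinct random rows per block
# by argsorting random keys and keeping the first 4 positions
big_inds = tf.argsort(tf.random.uniform((1, 16, 16)), axis=-1)[..., :4]  # shape (1, 16, 4)
big_out = tf.gather(big, big_inds, batch_dims=2)  # shape (1, 16, 4, 128)
```

No index tensor of full coordinates is needed: the indices only have to list the k row numbers per block, and their shape grows with BATCH_SIZE, A and k rather than with the coordinate depth.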