Reputation: 43
I have two tensors, A with shape [B,3000,3] and C with shape [B,4000]. I want to use tf.gather() so that each row of C supplies the indices and the matching batch element of A supplies the params, giving a result of shape [B,4000,3].
Here is an example to make this clearer. Say I have the tensors
A = [[1,2,3],[4,5,6],[7,8,9]],
C = [0,2,1,2,1],
then tf.gather(A,C) gives
result = [[1,2,3],[7,8,9],[4,5,6],[7,8,9],[4,5,6]].
This works fine for tensors with fewer than 3 dimensions.
But in the case described at the beginning, tf.gather(A,C,axis=1) returns a tensor of shape [B,B,4000,3]. It seems that tf.gather() applies every row of C as indices to every batch element of A, instead of pairing A[i] with C[i]. The only solution I can think of is a for loop that calls tf.gather(A[i,...],C[i,...]) for each i and stacks the results into the correct shape [B,4000,3], but that would hurt performance badly. Is there a function that can do this directly?
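To make the shape issue concrete, here is a minimal sketch in NumPy rather than TensorFlow, using np.take, which follows the same output-shape rule as a plain gather along an axis. The sizes (B=2, 5 rows instead of 3000, 4 indices instead of 4000) and the names naive, desired and diag are made up for illustration.

```python
import numpy as np

# Small stand-ins for the shapes in the question:
# A is [B, N, 3] and C is [B, M], with B=2, N=5, M=4.
B, N, M = 2, 5, 4
rng = np.random.default_rng(0)
A = rng.standard_normal((B, N, 3))
C = rng.integers(0, N, size=(B, M))

# A plain gather along axis 1 yields
# A.shape[:1] + C.shape + A.shape[2:] = (B, B, M, 3),
# because every row of C is applied to every batch element of A.
naive = np.take(A, C, axis=1)
print(naive.shape)

# What is actually wanted: pair A[i] with C[i] for each batch element.
desired = np.stack([A[i][C[i]] for i in range(B)])
print(desired.shape)

# The wanted result is the "diagonal" of the naive output: naive[i, i].
diag = naive[np.arange(B), np.arange(B)]
print(np.array_equal(diag, desired))
```

So the plain gather does contain the desired values, but it also computes and stores the B*(B-1) unwanted pairings.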
Upvotes: 1
Views: 1462
Reputation: 59691
You need to use tf.gather_nd:
import tensorflow as tf
A = ... # B x 3000 x 3
C = ... # B x 4000
s = tf.shape(C)
B, cols = s[0], s[1]
# Make indices for first dimension
idx = tf.tile(tf.expand_dims(tf.range(B, dtype=C.dtype), 1), [1, cols])
# Complete index for gather_nd
gather_idx = tf.stack([idx, C], axis=-1)
# Gather result
result = tf.gather_nd(A, gather_idx)
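As a quick check, the snippet above can be run on small concrete tensors and compared against the per-batch loop from the question. This is only a sketch, assuming TensorFlow 2.x with eager execution; the toy values and the names via_gather_nd, via_batch_dims and reference are made up for illustration. It also tries the batch_dims argument of tf.gather, which, in versions that support it, should give the same result without building the index tensor by hand.

```python
import numpy as np
import tensorflow as tf

# Toy stand-ins: A is [2, 4, 3] and C is [2, 5] (instead of [B,3000,3] and [B,4000]).
A = tf.constant(np.arange(2 * 4 * 3).reshape(2, 4, 3), dtype=tf.float32)
C = tf.constant([[0, 3, 3, 1, 2],
                 [2, 2, 0, 1, 3]], dtype=tf.int32)

# gather_nd approach: pair each index in C with its batch number.
s = tf.shape(C)
B, cols = s[0], s[1]
idx = tf.tile(tf.expand_dims(tf.range(B, dtype=C.dtype), 1), [1, cols])
gather_idx = tf.stack([idx, C], axis=-1)       # shape [2, 5, 2]
via_gather_nd = tf.gather_nd(A, gather_idx)    # shape [2, 5, 3]

# Alternative (assumes a TF version where tf.gather accepts batch_dims):
# treat the first dimension of both tensors as a shared batch dimension.
via_batch_dims = tf.gather(A, C, axis=1, batch_dims=1)

# Reference: the explicit per-batch loop from the question.
reference = tf.stack([tf.gather(A[i], C[i]) for i in range(2)])

print(via_gather_nd.shape, via_batch_dims.shape, reference.shape)
```

All three tensors should have shape [2, 5, 3] and hold the same values, with row i of the result built only from A[i] and C[i].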
Upvotes: 1