Kayee Joe

Reputation: 43

Use tf.gather to extract rows from a batched tensor using per-batch indices from another tensor (first dimension)

I have two tensors, A with shape [B,3000,3] and C with shape [B,4000]. I want to use tf.gather() so that each row of C supplies the indices for the corresponding row of A (the params), giving a result of shape [B,4000,3].

Here is an example to make this clearer. Say I have the tensors

A = [[1,2,3],[4,5,6],[7,8,9]],

C = [0,2,1,2,1],

and I get

result = [[1,2,3],[7,8,9],[4,5,6],[7,8,9],[4,5,6]]

by using tf.gather(A,C). This works fine when the tensors have fewer than 3 dimensions.

But in the case described at the beginning, applying tf.gather(A,C,axis=1) gives a result tensor of shape

[B,B,4000,3]
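
The extra dimension follows from how tf.gather builds its output shape when no batch dimensions are specified: the gathered axis of params is replaced by the whole shape of indices. A minimal pure-Python sketch of that rule (the helper name gather_output_shape is hypothetical, not a TensorFlow function, and B is set to 8 only for illustration):

```python
def gather_output_shape(params_shape, indices_shape, axis=0):
    """Shape rule of a plain gather: the gathered axis is replaced by the full indices shape."""
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]

B = 8  # illustrative batch size, not from the question
shape_axis1 = gather_output_shape([B, 3000, 3], [B, 4000], axis=1)
print(shape_axis1)  # [8, 8, 4000, 3]
```

So the batch dimension of C is not matched against the batch dimension of A; it is simply inserted into the output.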

It seems that tf.gather() uses every element of tensor C as an index into tensor A, regardless of which batch entry it belongs to. The only solution I can think of is a for loop over the batch that calls tf.gather(A[i,...],C[i,...]). That would be very slow, although it would give a tensor of the correct shape

[B,4000,3]

So, is there a function that can do this directly, without a loop?

Upvotes: 1

Views: 1462

Answers (1)

javidcf

Reputation: 59691

You need to use tf.gather_nd:

import tensorflow as tf

A = ...  # B x 3000 x 3
C = ...  # B x 4000
s = tf.shape(C)
B, cols = s[0], s[1]
# Make indices for first dimension
idx = tf.tile(tf.expand_dims(tf.range(B, dtype=C.dtype), 1), [1, cols])
# Complete index for gather_nd
gather_idx = tf.stack([idx, C], axis=-1)
# Gather result
result = tf.gather_nd(A, gather_idx)
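
To see what the index tensor looks like, here is a small pure-Python emulation of the steps above on nested lists. It is only a sketch: gather_nd_2d is a hypothetical helper that mimics tf.gather_nd for (batch, row) index pairs, and the values of A and C are made up for illustration.

```python
def gather_nd_2d(params, index_pairs):
    """Mimic gather_nd for a [B, cols] grid of (batch, row) index pairs."""
    return [[params[b][r] for (b, r) in row] for row in index_pairs]

# A has shape [2, 3, 2]; C has shape [2, 4] (illustrative values)
A = [[[1, 2], [3, 4], [5, 6]],
     [[7, 8], [9, 10], [11, 12]]]
C = [[0, 2, 1, 2],
     [1, 1, 0, 2]]

B, cols = len(C), len(C[0])
# Batch index repeated along each row, like the tf.tile(tf.expand_dims(...)) step
idx = [[b] * cols for b in range(B)]
# Pair each batch index with the row index from C, like tf.stack([idx, C], axis=-1)
gather_idx = [[(idx[b][j], C[b][j]) for j in range(cols)] for b in range(B)]
# Each pair (b, r) selects the row A[b][r]
result = gather_nd_2d(A, gather_idx)
print(result)
```

Each entry result[b][j] equals A[b][C[b][j]], which is the per-batch lookup asked for. Newer TensorFlow releases also accept a batch_dims argument on tf.gather; as far as I know, tf.gather(A, C, batch_dims=1) expresses the same lookup directly, but check the documentation for your version.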

Upvotes: 1
