Felipe Moser

Reputation: 365

How to use tf.gather in batch?

I have a tensor A of shape 10x1000 and an index tensor B of the same shape. B holds values in the range 0-999 and is used to gather values from A row by row (B[0,:] gathers from A[0,:], B[1,:] from A[1,:], and so on).

However, tf.gather(A, B) returns a tensor of shape (10, 1000, 1000), whereas I expect a 10x1000 tensor back. Any ideas how I could fix this?

EDIT

Let's say A = [[1, 2, 3], [4, 5, 6]] and B = [[0, 1, 1], [2, 1, 0]]. I want to sample each row of A using the corresponding row of B. This should result in C = [[1, 2, 2], [6, 5, 4]].

Upvotes: 3

Views: 4876

Answers (2)

Pengcheng Fan

Reputation: 121

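Use tf.gather with batch_dims=-1. A negative batch_dims is counted from the rank of the indices, so for 2-D indices it is equivalent to batch_dims=1: each row of the indices gathers from the matching row of the parameters.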

import numpy as np
import tensorflow as tf

rois = np.array([[1, 2, 3],[3, 2, 1]])
ind = np.array([[0, 2, 1, 1, 2, 0, 0, 1, 1, 2],
        [0, 1, 2, 0, 2, 0, 1, 2, 2, 2]])
tf.gather(rois, ind, batch_dims=-1)
# output:
# <tf.Tensor: shape=(2, 10), dtype=int64, numpy=
# array([[1, 3, 2, 2, 3, 1, 1, 2, 2, 3],
#       [3, 2, 1, 3, 1, 3, 2, 1, 1, 1]])>
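
Applied to the example from the question, the same call might look like the sketch below. It assumes TensorFlow 2.x with eager execution; the names big_a, big_b and big_c are illustrative and not part of the question.

```python
import numpy as np
import tensorflow as tf

# The small example from the question.
A = tf.constant([[1, 2, 3], [4, 5, 6]])
B = tf.constant([[0, 1, 1], [2, 1, 0]])

# batch_dims=1 pairs row i of B with row i of A;
# for rank-2 indices, batch_dims=-1 means the same thing.
C = tf.gather(A, B, batch_dims=1)
print(C.numpy())
# [[1 2 2]
#  [6 5 4]]

# The 10x1000 case: the result keeps the shape of the index tensor.
big_a = tf.constant(np.random.normal(size=(10, 1000)))
big_b = tf.constant(np.random.randint(0, 1000, size=(10, 1000)))
big_c = tf.gather(big_a, big_b, batch_dims=1)
print(big_c.shape)  # (10, 1000)
```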

Upvotes: 2

Vlad

Reputation: 8605

  1. Dimensions of tensors are known in advance.

First we 'unstack' both the parameters and the indices (A and B respectively) along the first dimension. Then we apply tf.gather() to each pair, so that row i of B gathers from row i of A. Finally, we stack the results back together.

import tensorflow as tf
import numpy as np

def custom_gather(a, b):
    unstacked_a = tf.unstack(a, axis=0)
    unstacked_b = tf.unstack(b, axis=0)
    gathered = [tf.gather(x, y) for x, y in zip(unstacked_a, unstacked_b)]
    return tf.stack(gathered, axis=0)

a = tf.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6]]), tf.float32)
b = tf.convert_to_tensor(np.array([[0, 1, 1], [2, 1, 0]]), dtype=tf.int32)

gathered = custom_gather(a, b)
# Note: tf.Session is the TensorFlow 1.x graph-mode API.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gathered))
# [[1. 2. 2.]
#  [6. 5. 4.]]

For your initial case with shapes 10x1000 we get:

a = tf.convert_to_tensor(np.random.normal(size=(10, 1000)), tf.float32)
# high is exclusive, so high=1000 draws indices in the range 0-999
b = tf.convert_to_tensor(np.random.randint(low=0, high=1000, size=(10, 1000)), dtype=tf.int32)
gathered = custom_gather(a, b)
print(gathered.get_shape().as_list()) # [10, 1000]

Update

  2. The first dimension is unknown (i.e. None)

The previous solution works only if the first dimension is known in advance. If the dimension is unknown we solve it as follows:

  • We stack together two tensors such that the rows of both tensors are stacked together:
# A = [[1, 2, 3], [4, 5, 6]]        [[[1 2 3]
#                            --->     [0 1 1]]
#                                    [[4 5 6]
# B = [[0, 1, 1], [2, 1, 0]]          [2 1 0]]]
  • We iterate over the elements of this stacked tensor (each element holds one row of A and the matching row of B) and apply tf.gather() to each element using tf.map_fn().

  • We stack back the elements we get with tf.stack()

import tensorflow as tf
import numpy as np

def custom_gather_v2(a, b):
    def apply_gather(x):
        return tf.gather(x[0], tf.cast(x[1], tf.int32))
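    # tf.stack needs a single dtype, so the indices are temporarily stored as
    # float32 (exact for index values below 2**24) and cast back in apply_gather.
    a = tf.cast(a, dtype=tf.float32)
    b = tf.cast(b, dtype=tf.float32)
    stacked = tf.stack([a, b], axis=1)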
    gathered = tf.map_fn(apply_gather, stacked)
    return tf.stack(gathered, axis=0)

a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
b = np.array([[0, 1, 1], [2, 1, 0]], dtype=np.int32)

x = tf.placeholder(tf.float32, shape=(None, 3))
y = tf.placeholder(tf.int32, shape=(None, 3))

gathered = custom_gather_v2(x, y)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gathered, feed_dict={x:a, y:b}))
# [[1. 2. 2.]
#  [6. 5. 4.]]
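
A possible variation on the approach above is to pass A and B to tf.map_fn as a tuple, which avoids casting the indices to float32. This is a sketch, not the original answer's method. It assumes TensorFlow 2.3 or later, where tf.map_fn accepts fn_output_signature. The name gather_rows and the (None, 3) input signature are illustrative.

```python
import tensorflow as tf

# The None in the input signature keeps the first dimension unknown at trace time.
@tf.function(input_signature=[
    tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(None, 3), dtype=tf.int32),
])
def gather_rows(a, b):
    # map_fn slices both tensors along axis 0 and passes each
    # (row of a, row of b) pair to the lambda.
    return tf.map_fn(
        lambda pair: tf.gather(pair[0], pair[1]),
        (a, b),
        fn_output_signature=tf.float32,
    )

a = tf.constant([[1., 2., 3.], [4., 5., 6.]])
b = tf.constant([[0, 1, 1], [2, 1, 0]], dtype=tf.int32)
result = gather_rows(a, b)
print(result.numpy())
# [[1. 2. 2.]
#  [6. 5. 4.]]
```

Because the indices stay in int32 throughout, this variant does not depend on float32 representing every index value exactly.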

Upvotes: 3
