Wells

Reputation: 9

Extracting/gathering elements from a tensor according to two index vectors along one dimension in TensorFlow

I'm sorry for the long, tedious question title, which is hard to understand. Basically, I'd like to implement the following function in TensorFlow:

For example, take a tensor A with shape [10, 10, 7, 1] and an index matrix B = array([[1,3,5],[2,4,6]]). I'd like to extract the elements of A along axis 2 (following Python convention, A has four axes: 0, 1, 2, 3) according to the indices in each row of B.

So the result for this example should be a tensor C with shape [10, 10, 3, 2] (the trailing size-1 axis of A is dropped). The third dimension comes from selecting elements of A along axis 2 according to the indices [1,3,5] or [2,4,6], and the fourth dimension equals the first dimension of B (i.e. the number of rows of B), since we do two selections along that axis.
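To pin down the intended semantics, here is a minimal NumPy sketch (NumPy rather than TensorFlow, with random data standing in for A). It assumes my reading of the description above, namely C[i, j, k, r] = A[i, j, B[r, k], 0]:

```python
import numpy as np

# Reference semantics (assumed): C[i, j, k, r] = A[i, j, B[r, k], 0]
A = np.random.randn(10, 10, 7, 1).astype(np.float32)
B = np.array([[1, 3, 5],
              [2, 4, 6]])

# Drop the trailing size-1 axis, then fancy-index axis 2 with the whole of B.
# Indexing an axis with a (2, 3) integer array yields shape (10, 10, 2, 3).
picked = A[..., 0][:, :, B]

# Move the "row of B" axis to the end: (10, 10, 2, 3) -> (10, 10, 3, 2).
C = picked.transpose(0, 1, 3, 2)

print(C.shape)  # (10, 10, 3, 2)
```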

Is there a "tensor-flavored" way to implement this in TensorFlow, instead of doing it in two steps? I didn't see a way of using tf.gather_nd() or tf.gather() for this. Any ideas? Many thanks!
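For what it's worth, tf.gather takes an axis argument, and its output shape is params.shape[:axis] + indices.shape + params.shape[axis + 1:], so a 2-D index matrix can be used in a single call followed by a transpose. A minimal sketch, assuming TensorFlow 2.x eager execution and random data standing in for A:

```python
import numpy as np
import tensorflow as tf

# Assumes TF 2.x eager execution; random values stand in for the real A.
A = tf.constant(np.random.randn(10, 10, 7, 1).astype(np.float32))
B = tf.constant([[1, 3, 5],
                 [2, 4, 6]])  # int32 indices into axis 2 of A

# Squeeze the trailing size-1 axis: (10, 10, 7, 1) -> (10, 10, 7).
A3 = tf.squeeze(A, axis=-1)

# Gathering along axis 2 with a (2, 3) index tensor gives
# A3.shape[:2] + B.shape = (10, 10, 2, 3).
picked = tf.gather(A3, B, axis=2)

# Reorder to (10, 10, 3, 2) so the rows of B end up on the last axis.
C = tf.transpose(picked, perm=[0, 1, 3, 2])

print(C.shape)
```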

An additional example:

A = [[[1],     # A is (3, 5, 1)
      [2],
      [3],
      [4],
      [5]],
     [[10],
      [20],
      [30],
      [40],
      [50]],
     [[100],
      [200],
      [300],
      [400],
      [500]]]

B = [[1,4,3],     # B is (2,3)
     [2,3,5]]

C = [[[1, 2],     # C is (3, 3, 2)
      [4, 3],
      [3, 5]],
     [[10, 20],
      [40, 30],
      [30, 50]],
     [[100, 200],
      [400, 300],
      [300, 500]]]
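Checking this smaller example in NumPy suggests that the indices in B are meant to be 1-based here (index 5 would be out of range for an axis of length 5 under 0-based indexing), and that the gather runs along axis 1 of this A. A minimal sketch under that assumption:

```python
import numpy as np

# The small example: A is (3, 5, 1), B is (2, 3).
A = np.array([[[1], [2], [3], [4], [5]],
              [[10], [20], [30], [40], [50]],
              [[100], [200], [300], [400], [500]]])
B = np.array([[1, 4, 3],
              [2, 3, 5]])

# Assumption: B is 1-based, so shift it to 0-based before indexing.
# Gather along axis 1: (3, 5) indexed by (2, 3) -> (3, 2, 3).
picked = A[:, :, 0][:, B - 1]

# Swap the last two axes: (3, 2, 3) -> (3, 3, 2).
C = picked.transpose(0, 2, 1)
print(C[0].tolist())  # [[1, 2], [4, 3], [3, 5]]
```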

Upvotes: 0

Views: 1112

Answers (1)

Patwie

Reputation: 4450

The shape of your B tensor looks wrong, and your question is hard to parse. Anyway, TF is not very elegant at this kind of problem: tf.gather_nd needs B in a very specific shape, where the last dimension holds a full index tuple (one coordinate per axis of A). Try something similar to:

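import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (graph mode with a session)

# Random data, with marker values at the positions we expect to gather.
A = np.random.randn(10, 10, 7, 1).astype(np.float32)
A[0, 0, 1, 0] = 100001
A[0, 0, 3, 0] = 100002
A[0, 0, 5, 0] = 100003
A[0, 0, 2, 0] = 100004
A[0, 0, 4, 0] = 100005
A[0, 0, 6, 0] = 100006
A = tf.convert_to_tensor(A)

sess = tf.InteractiveSession()

B = np.array([
    [1, 3, 5],
    [2, 4, 6]
])

B = tf.convert_to_tensor(B)
# Flatten the indices: [1, 3, 5, 2, 4, 6].
B = tf.reshape(B, [-1])
# Stack the coordinates (0, 0, b, 0) as four rows of length 6: shape [4, 6].
B = tf.concat([tf.zeros_like(B), tf.zeros_like(B), B, tf.zeros_like(B)], axis=-1)
B = tf.reshape(B, [4, -1])
# One full (axis0, axis1, axis2, axis3) tuple per row: shape [6, 4].
B = tf.transpose(B, [1, 0])
# Restore the layout of the original B: shape [1, 2, 3, 4].
B = tf.reshape(B, [1, 2, 3, -1])

# Each 4-tuple selects one scalar of A, so C has shape [1, 2, 3].
C = tf.gather_nd(A, B)
C = sess.run(C)
print(C.shape)
print(C)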

The output is:

[[[100001. 100002. 100003.]
  [100004. 100005. 100006.]]]
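The code above only gathers from A[0, 0, :, 0], so its result has shape [1, 2, 3] rather than the [10, 10, 3, 2] the question asks for. A sketch of the same tf.gather_nd idea extended to every output position builds one full index tuple (i, j, B[r, k], 0) per output element. It assumes TensorFlow 2.x eager execution and random data standing in for A:

```python
import numpy as np
import tensorflow as tf

# Assumes TF 2.x eager execution; random values stand in for the real A.
A = tf.constant(np.random.randn(10, 10, 7, 1).astype(np.float32))
B = tf.constant([[1, 3, 5],
                 [2, 4, 6]])  # int32, shape (2, 3)

# One coordinate grid per output axis; each has shape (10, 10, 3, 2).
i, j, k, r = tf.meshgrid(tf.range(10), tf.range(10), tf.range(3), tf.range(2),
                         indexing="ij")

# Look up B[r, k] for every output position: shape (10, 10, 3, 2).
b = tf.gather_nd(B, tf.stack([r, k], axis=-1))

# Full index tuples (i, j, B[r, k], 0), one coordinate per axis of A:
# shape (10, 10, 3, 2, 4).
idx = tf.stack([i, j, b, tf.zeros_like(i)], axis=-1)

# gather_nd reads one scalar of A per tuple, so C has shape (10, 10, 3, 2).
C = tf.gather_nd(A, idx)
print(C.shape)
```

Note that this materializes an index tensor four times the size of the output, which is the price of expressing the selection through explicit coordinates.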

Upvotes: 1
