Sorry for the long, hard-to-parse question title. Basically, I'd like to implement the following in TensorFlow:
For example, take a tensor A with shape [10, 10, 7, 1] and an index matrix B = array([[1,3,5],[2,4,6]]). I'd like to extract elements of A along axis 2 (following the Python convention, A has four axes: 0, 1, 2, 3) according to the indices in each row of B.
The result for this example should be a tensor C with shape [10, 10, 3, 2]. The third dimension comes from selecting elements of A along axis 2 at indices [1,3,5] or [2,4,6], and the fourth dimension equals the first dimension of B (i.e. the number of rows of B), since we make one selection per row of B.
Is there a "tensor-flavored" way to implement this in TensorFlow, instead of doing it in two steps? I didn't see a way to use tf.gather_nd() or tf.gather() for it. Any ideas? Many thanks!
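An additional example (here A is 3-D, the selection runs along axis 1, and the values in B are 1-based, so index 1 picks the first element):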
A = [[[1], [2], [3], [4], [5]],            # A is (3, 5, 1)
     [[10], [20], [30], [40], [50]],
     [[100], [200], [300], [400], [500]]]
B = [[1, 4, 3],                            # B is (2, 3)
     [2, 3, 5]]
C = [[[1, 2], [4, 3], [3, 5]],             # C is (3, 3, 2)
     [[10, 20], [40, 30], [30, 50]],
     [[100, 200], [400, 300], [300, 500]]]
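As a reference for the intended semantics, here is a NumPy sketch (NumPy is an assumption here; the question itself is about TensorFlow) that reproduces C from the second example. It shifts B to 0-based indices first, then uses np.take, whose output shape for a 2-D index array is A.shape[:axis] + B.shape + A.shape[axis+1:].

```python
import numpy as np

# Second example from the question: A is (3, 5, 1).
A = np.array([[[1], [2], [3], [4], [5]],
              [[10], [20], [30], [40], [50]],
              [[100], [200], [300], [400], [500]]])
# B as written in the question uses 1-based indices.
B = np.array([[1, 4, 3],
              [2, 3, 5]])

# np.take along axis 1 with a (2, 3) index array gives shape
# A.shape[:1] + B.shape + A.shape[2:] = (3, 2, 3, 1).
picked = np.take(A, B - 1, axis=1)
# Drop the trailing singleton axis and move the "row of B" axis last:
# (3, 2, 3) -> (3, 3, 2), so C[i, j, r] = A[i, B[r, j] - 1, 0].
C = picked[..., 0].transpose(0, 2, 1)

print(C.shape)  # (3, 3, 2)
print(C)
```

The same "gather, then transpose" pattern carries over to TensorFlow, since tf.gather with an axis argument follows the same output-shape rule.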
The shape of your B tensor looks wrong and your question is hard to parse. Either way, TF does not handle this problem very elegantly: tf.gather_nd requires B to be rearranged into a very specific shape, where the last axis holds one complete index tuple into A. Try something similar to:
import tensorflow as tf
import numpy as np
A = np.random.randn(10, 10, 7, 1).astype(np.float32)
# Plant recognizable values at A[0, 0, idx, 0] for every idx in B.
A[0, 0, 1, 0] = 100001
A[0, 0, 3, 0] = 100002
A[0, 0, 5, 0] = 100003
A[0, 0, 2, 0] = 100004
A[0, 0, 4, 0] = 100005
A[0, 0, 6, 0] = 100006
A = tf.convert_to_tensor(A)
B = np.array([
    [1, 3, 5],
    [2, 4, 6]
])
B = tf.convert_to_tensor(B)
B = tf.reshape(B, [-1])  # [1, 3, 5, 2, 4, 6]
# Build one index tuple (0, 0, b, 0) for every b in B.
B = tf.concat([tf.zeros_like(B), tf.zeros_like(B), B, tf.zeros_like(B)], axis=-1)
B = tf.reshape(B, [4, -1])
B = tf.transpose(B, [1, 0])  # shape [6, 4]
B = tf.reshape(B, [1, 2, 3, -1])  # shape [1, 2, 3, 4]
C = tf.gather_nd(A, B)  # shape [1, 2, 3]
C = C.numpy()  # TF 2.x eager execution; in TF 1.x use sess.run(C) instead
print(C.shape)
print(C)
The output is:
(1, 2, 3)
[[[100001. 100002. 100003.]
[100004. 100005. 100006.]]]
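The snippet above only fills in C for A[0, 0, ...]. Below is a sketch, assuming TensorFlow 2.x with eager execution, of two ways to get the full [10, 10, 3, 2] result the question asks for. The first uses tf.gather with its axis argument: with a 2-D indices tensor the output shape is A.shape[:2] + B.shape + A.shape[3:], i.e. [10, 10, 2, 3, 1], so only a squeeze and a transpose remain. The second extends the gather_nd idea by building the complete index grid with tf.meshgrid and tf.broadcast_to. The random A and the variable names are illustrative only.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

A = tf.random.normal([10, 10, 7, 1])
B = tf.constant([[1, 3, 5],
                 [2, 4, 6]])  # 0-based indices into axis 2 of A
num_rows, num_sel = B.shape  # 2 rows of B, 3 indices per row

# Option 1: tf.gather along axis 2.
# g[i, j, r, k, 0] = A[i, j, B[r, k], 0], shape [10, 10, 2, 3, 1].
g = tf.gather(A, B, axis=2)
# Drop the trailing singleton axis and put the "row of B" axis last.
C_gather = tf.transpose(g[..., 0], [0, 1, 3, 2])  # [10, 10, 3, 2]

# Option 2: gather_nd with a full index grid.
# indices[i, j, k, r] = (i, j, B[r, k], 0), so the result matches Option 1.
n0, n1 = 10, 10
shape = [n0, n1, num_sel, num_rows]
I, J = tf.meshgrid(tf.range(n0), tf.range(n1), indexing="ij")  # each [10, 10]
I = tf.broadcast_to(I[:, :, None, None], shape)
J = tf.broadcast_to(J[:, :, None, None], shape)
K = tf.broadcast_to(tf.transpose(B), shape)  # K[i, j, k, r] = B[r, k]
Z = tf.zeros_like(K)                         # last axis of A has size 1
indices = tf.stack([I, J, K, Z], axis=-1)    # [10, 10, 3, 2, 4]
C_gather_nd = tf.gather_nd(A, indices)       # [10, 10, 3, 2]

print(C_gather.shape, C_gather_nd.shape)
```

Both variants should produce identical tensors; the tf.gather version is shorter and avoids materializing the [10, 10, 3, 2, 4] index tensor.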