Reputation: 81
I would like to select a part of this tensor.
A = tf.constant([[[1,1],[2,2],[3,3]], [[4,4],[5,5],[6,6]]])
The value of A is
[[[1 1]
  [2 2]
  [3 3]]

 [[4 4]
  [5 5]
  [6 6]]]
The indices I want to select from A are [1, 0]. That is, row 1 ([2 2]) of the first block and row 0 ([4 4]) of the second block of this tensor, so my expected result is
[2 2]
[4 4]
How can I do this with the embedding_lookup function? I have already tried this:
B = tf.nn.embedding_lookup(A, [1, 0])
but the result is not what I expected:
[[[4 4]
  [5 5]
  [6 6]]

 [[1 1]
  [2 2]
  [3 3]]]
Can anyone help me and explain how to do it?
Upvotes: 0
Views: 242
Reputation: 11343
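Try the following:
import numpy as np
import tensorflow as tf

A = tf.constant([[[1,1],[2,2],[3,3]], [[4,4],[5,5],[6,6]]])
B = [1,0]
# Pair each block position with its row index: [(0, 1), (1, 0)]
inds = [(a,b) for a,b in zip(np.arange(len(B)), B)]
C = tf.gather_nd(params=A, indices=inds)  # [[2 2], [4 4]]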
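As for why the original attempt fails: embedding_lookup indexes only along the first axis, so [1, 0] picks whole blocks (A[1], then A[0]) rather than one row per block. gather_nd instead takes one full (block, row) pair per output row. The difference can be checked with a small NumPy sketch. This is only an illustration of the indexing logic, not TensorFlow code, and the names `swapped` and `picked` are made up for the example:

```python
import numpy as np

A = np.array([[[1, 1], [2, 2], [3, 3]],
              [[4, 4], [5, 5], [6, 6]]])
B = [1, 0]

# embedding_lookup-style selection: indexes axis 0 only,
# so the two blocks are returned whole, in swapped order
swapped = A[B]

# gather_nd-style selection: one (block, row) pair per output row
inds = list(zip(range(len(B)), B))   # [(0, 1), (1, 0)]
picked = np.array([A[i, j] for i, j in inds])

print(picked)
```

If your TensorFlow version supports the batch_dims argument of tf.gather, tf.gather(A, B, axis=1, batch_dims=1) should give the same result without building the index pairs by hand.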
Upvotes: 1