I have two tensors.
v, shape=(50, 64, 128), dtype=float32
m, shape=(64, 50, 1), dtype=int32
The values in m are integers in the range [0, 50), i.e. at most 49.
I want to use the values of m to pick, for each index i out of the 64, a specific slice of v.
The resulting tensor r should have shape=(64, 50, 128), dtype=float32.
In other words, r[i, j, :] = v[m[i, j, 0], i, :] for every i and j.
The closest thing I see is tf.nn.embedding_lookup, but I'm not sure how to use it for this case.
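As a plain reference for the rule r[i, j, :] = v[m[i, j, 0], i, :], here is a NumPy sketch (not part of the original question) that writes it once as an explicit loop and once with advanced indexing. The random data is hypothetical and only reproduces the shapes and dtypes above.

```python
import numpy as np

# Hypothetical random data with the shapes/dtypes from the question
rng = np.random.default_rng(0)
v = rng.standard_normal((50, 64, 128)).astype(np.float32)
m = rng.integers(0, 50, size=(64, 50, 1), dtype=np.int32)

# Direct transcription of r[i, j, :] = v[m[i, j, 0], i, :]
r_loop = np.empty((64, 50, 128), dtype=np.float32)
for i in range(64):
    for j in range(50):
        r_loop[i, j, :] = v[m[i, j, 0], i, :]

# Same result with advanced indexing: index arrays of shape (64, 50) and
# (64, 1) broadcast to (64, 50); the last axis (128) is kept whole
r = v[m[:, :, 0], np.arange(64)[:, None]]

print(r.shape)                    # (64, 50, 128)
print(np.array_equal(r, r_loop))  # True
```

The loop version is only there to make the indexing rule explicit; any TensorFlow solution can be checked against it element by element.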
You can use either tf.nn.embedding_lookup or tf.gather_nd to achieve this:
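import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# Indices in [0, 50); cast to int32 to match the question and tf.range's dtype
m_np = np.random.randint(0, 50, (64, 50, 1)).astype(np.int32)
m = tf.constant(m_np)                 # shape=(64,50,1)
n = tf.random.normal((50, 64, 128))   # plays the role of v

# Method 1: look up every row of n, then keep the diagonal where the
# batch index i of m equals the second axis of n.
# Note: tmp is 64 times larger than the result.
tmp = tf.nn.embedding_lookup(n, m[:, :, 0])  # shape=(64,50,64,128)
tmp = tf.transpose(tmp, [1, 3, 0, 2])        # shape=(50,128,64,64)
result1 = tf.transpose(tf.linalg.diag_part(tmp), [2, 0, 1])  # shape=(64,50,128)

# Method 2: pair each m[i, j] with its batch index i and gather directly
indices = tf.tile(tf.reshape(tf.range(64), (-1, 1, 1)), (1, 50, 1))  # shape=(64,50,1)
indices = tf.concat([m, indices], axis=-1)   # shape=(64,50,2), entries (m[i,j], i)
result2 = tf.gather_nd(n, indices)           # shape=(64,50,128)

# Spot-check one location (i=5, j=4) for both methods
n_value = n.numpy()
expected = n_value[m_np[5, 4, 0], 5, :]
print((expected == result1.numpy()[5, 4]).all())  # True
print((expected == result2.numpy()[5, 4]).all())  # True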
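To see why both methods produce the same tensor without running TensorFlow, here is a NumPy sketch that mirrors them step by step. It assumes np.take(..., axis=0) as a stand-in for tf.nn.embedding_lookup with a single params tensor, np.diagonal over the last two axes for the diagonal extraction (tf.matrix_diag_part / tf.linalg.diag_part), and integer-array indexing with the two coordinates for tf.gather_nd. These are analogues, not the TensorFlow calls themselves, and the data is hypothetical.

```python
import numpy as np

# Hypothetical data with the shapes from the question
rng = np.random.default_rng(1)
n = rng.standard_normal((50, 64, 128)).astype(np.float32)
m = rng.integers(0, 50, size=(64, 50, 1), dtype=np.int32)

# Method 1 analogue: np.take along axis 0 returns ids.shape + n.shape[1:],
# i.e. a full (64, 128) slice for every m[i, j]
tmp = np.take(n, m[:, :, 0], axis=0)         # (64, 50, 64, 128)
tmp = np.transpose(tmp, (1, 3, 0, 2))        # (50, 128, 64, 64)
diag = np.diagonal(tmp, axis1=2, axis2=3)    # (50, 128, 64), keeps k == i
result1 = np.transpose(diag, (2, 0, 1))      # (64, 50, 128)

# Method 2 analogue: build (m[i, j], i) coordinate pairs like the
# tf.gather_nd indices, then index the first two axes of n with them
rows = np.tile(np.arange(64, dtype=np.int32).reshape(-1, 1, 1), (1, 50, 1))  # (64, 50, 1)
indices = np.concatenate([m, rows], axis=-1)   # (64, 50, 2)
result2 = n[indices[..., 0], indices[..., 1]]  # (64, 50, 128)

# Both should match the rule r[i, j, :] = n[m[i, j, 0], i, :]
reference = n[m[:, :, 0], np.arange(64)[:, None]]
print(np.array_equal(result1, reference), np.array_equal(result2, reference))  # True True
```

The sketch also makes the cost difference visible: Method 1 builds an intermediate with 64 times as many elements as the result and then discards all but the diagonal, while Method 2 reads only the slices it needs.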