Milad

Embedding lookup from a specific axis

I have two tensors.

v, shape=(50, 64, 128), dtype=float32
m, shape=(64, 50, 1), dtype=int32

Values in m are integers from 0 to 49. For each of the 64 indices i, I want to use the values of m to pick the matching slice of v along its first axis. The resulting tensor r has shape=(64, 50, 128), dtype=float32.

That is, r(i, j, :) = v(m(i, j, 0), i, :), where : spans all 128 entries of the last axis.
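To make the indexing rule concrete, here is a small NumPy reference implementation of r(i, j, :) = v(m(i, j, 0), i, :). It is only a sketch for checking results: the shapes come from the question, the random data is made up, and `reference_lookup` is a hypothetical helper name. The loop spells the rule out directly; the vectorized line uses NumPy advanced indexing, pairing m[..., 0] (shape (64, 50)) with a row index of shape (64, 1) that broadcasts against it.

```python
import numpy as np

def reference_lookup(v, m):
    """Hypothetical helper: r[i, j, :] = v[m[i, j, 0], i, :] using explicit loops."""
    num_i, num_j, _ = m.shape
    r = np.empty((num_i, num_j, v.shape[2]), dtype=v.dtype)
    for i in range(num_i):
        for j in range(num_j):
            r[i, j, :] = v[m[i, j, 0], i, :]
    return r

# Made-up data with the shapes and dtypes from the question.
rng = np.random.default_rng(0)
v = rng.standard_normal((50, 64, 128)).astype(np.float32)
m = rng.integers(0, 50, size=(64, 50, 1), dtype=np.int32)

r_loop = reference_lookup(v, m)

# Vectorized: index arrays of shape (64, 50) and (64, 1) broadcast to (64, 50);
# the trailing ':' keeps the 128-long last axis.
rows = np.arange(v.shape[1])[:, None]
r_vec = v[m[..., 0], rows, :]

print(r_vec.shape)  # (64, 50, 128)
print(np.array_equal(r_loop, r_vec))  # True
```

The loop version is slow but hard to get wrong, which makes it a convenient baseline for testing any TensorFlow version against.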

The closest thing I can find is tf.nn.embedding_lookup, but I'm not sure how to use it for this case.
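For a single params tensor, tf.nn.embedding_lookup(params, ids) gathers along axis 0, much like np.take(params, ids, axis=0). The NumPy sketch below (an approximation, not TensorFlow itself, with made-up random data) shows why a plain lookup is not enough: for each (i, j) it returns all 64 slices of v's second axis, so the result has shape (64, 50, 64, 128), and the wanted values sit on the "diagonal" where the third index equals i.

```python
import numpy as np

# Made-up data with the shapes from the question.
rng = np.random.default_rng(1)
v = rng.standard_normal((50, 64, 128)).astype(np.float32)
m = rng.integers(0, 50, size=(64, 50, 1), dtype=np.int32)

# Gather along axis 0, the NumPy analogue of an embedding lookup.
full = np.take(v, m[..., 0], axis=0)  # full[i, j, k, :] == v[m[i, j, 0], k, :]
print(full.shape)  # (64, 50, 64, 128)

# Keep only full[i, j, i, :]. np.diagonal over axes 0 and 2 removes those axes
# and appends the diagonal as the last axis, giving shape (50, 128, 64).
diag = np.diagonal(full, axis1=0, axis2=2)
r = np.transpose(diag, (2, 0, 1))  # reorder to (64, 50, 128)
print(r.shape)  # (64, 50, 128)

i, j = 5, 4
print(np.array_equal(r[i, j], v[m[i, j, 0], i]))  # True
```

Note the cost of this route: the intermediate `full` holds 64 × 50 × 64 × 128 float32 values (about 100 MB), 64 times more than the final result.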

Answers (1)

giser_yugang

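You can use either tf.nn.embedding_lookup or tf.gather_nd to achieve this. Method 1 looks up all 64 slices for every index and then keeps the diagonal, which creates a large (64, 50, 64, 128) intermediate tensor; Method 2 builds (m(i, j), i) index pairs and gathers exactly the slices you need. The code uses the TensorFlow 1.x API (tf.Session, tf.matrix_diag_part); in TensorFlow 2.x the diagonal op is tf.linalg.diag_part.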

import tensorflow as tf
import numpy as np

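# m must be int32 so that it can be concatenated with tf.range(64) in Method 2
m_np = np.random.randint(0, 50, (64, 50, 1)).astype(np.int32)
m = tf.constant(m_np)
n = tf.random.normal((50, 64, 128))  # plays the role of v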

# Method 1: look up all 64 slices for each id, then keep the diagonal
tmp = tf.nn.embedding_lookup(n, m[:, :, 0])  # shape=(64,50,64,128)
tmp = tf.transpose(tmp, [1, 3, 0, 2])  # shape=(50,128,64,64), both 64-axes last
result1 = tf.transpose(tf.matrix_diag_part(tmp), [2, 0, 1])  # (50,128,64) -> (64,50,128)

# Method 2: build [m(i, j), i] index pairs and gather them directly
indices = tf.tile(tf.reshape(tf.range(64), (-1, 1, 1)), (1, 50, 1))  # shape=(64,50,1), value i
indices = tf.concat([m, indices], axis=-1)  # shape=(64,50,2)
result2 = tf.gather_nd(n, indices)  # shape=(64,50,128)

with tf.Session() as sess:
    # Check one location against the definition r(i, j, :) = v(m(i, j), i, :)
    n_value, r1, r2 = sess.run([n, result1, result2])
    expected = n_value[m_np[5, 4, 0], 5, :]
    print((expected == r1[5, 4]).all(), (expected == r2[5, 4]).all())

# True True
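To see what the index tensor in Method 2 contains, the NumPy sketch below rebuilds it with the same tile and concatenate steps, then consumes it the way tf.gather_nd does, treating the last axis of the indices as coordinates into the first two axes of n. It is a NumPy approximation for inspection only, with made-up random data; it does not call TensorFlow.

```python
import numpy as np

# Made-up data with the shapes from the question; m is int32 like tf.range's default.
rng = np.random.default_rng(2)
n = rng.standard_normal((50, 64, 128)).astype(np.float32)
m = rng.integers(0, 50, size=(64, 50, 1), dtype=np.int32)

# Row index i for every (i, j), shape (64, 50, 1),
# mirroring tf.tile(tf.reshape(tf.range(64), (-1, 1, 1)), (1, 50, 1)).
rows = np.tile(np.arange(64, dtype=np.int32).reshape(-1, 1, 1), (1, 50, 1))
indices = np.concatenate([m, rows], axis=-1)  # shape (64, 50, 2): [m[i, j, 0], i]

# gather_nd with a length-2 index vector selects n[a, b, :]; in NumPy this is
# advanced indexing with one index array per leading axis.
result2 = n[indices[..., 0], indices[..., 1]]
print(result2.shape)  # (64, 50, 128)

# Compare every position against the definition r(i, j, :) = v(m(i, j), i, :).
ok = all(np.array_equal(result2[i, j], n[m[i, j, 0], i])
         for i in range(64) for j in range(50))
print(ok)  # True
```

Both halves of the concatenation must share a dtype, which is why m is created as int32 here, matching the default dtype of tf.range.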
