CentAu

Reputation: 11200

TensorFlow, select embeddings from input with a specific token id and batch the results

I have an input tensor of token ids, input_ids, with shape [B x T] and a corresponding embedding matrix with shape [B x T x D] (B: batch size, T: sequence length, D: embedding dimension). The input ids are vocabulary ids, and the embedding matrix contains the corresponding embeddings.

From the embedding matrix I want to select the elements with a certain id (e.g., 103). It would be easy to do this using tf.where and tf.gather_nd; what I don't know how to do is organize the results into a batch of shape [B x N x D], where N is the maximum number of tokens with that id (103) in any sequence, padding with zero vectors as needed.

Some code might show it better (let's say B=2, T=8, and D=3):

import tensorflow as tf
tf.enable_eager_execution()

input_ids = tf.constant([[  101,  1996, 16360,  103,  1010,  1996,  4223,  1997],
                         [  101,  103,  3793,  103,  2443,  2000,  103,  2469]])
embeddings = tf.random_normal((2,8,3))

# input_ids contains two sequences: the first has one 103 token, while the second has three.

I want to select from embeddings the vectors at the positions where input_ids == 103 and pad the remaining slots with zeros. I can get the raw selection with:

indices = tf.where(tf.equal(input_ids, 103))
result = tf.gather_nd(indices=indices, params=embeddings)
# result.shape == [4 x 3]

# This results in a [4 x 3] matrix, where 4 = total number of 103 tokens
# in the batch and 3 is the embedding dimension.
# Now I want to organize this into a batch with the same batch size as
# the input, i.e., desired shape = (2 x 3 x 3), where the first (3 x 3)
# slice holds the single 103 embedding of the first sequence followed by
# two rows of zero padding, and the second (3 x 3) slice holds all three
# 103 embeddings of the second sequence.

In general, the above gives an [M x D] tensor (M = total number of 103 tokens in the batch). What I want is a [B x N x D] tensor, where N is the maximum number of 103 tokens in any single sequence (3 for the case above). I hope the description is clear (the exact problem is kind of hard to explain).
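
For reference, N for this example can be computed directly from input_ids (a minimal sketch using the tensors defined above):

counts = tf.reduce_sum(tf.cast(tf.equal(input_ids, 103), tf.int32), axis=-1)
# counts == [1 3], the number of 103 tokens in each sequence
N = tf.reduce_max(counts)
# N == 3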

How can I achieve this?

Upvotes: 0

Views: 599

Answers (1)

giser_yugang

Reputation: 6176

I think we can take advantage of the property that tf.gather_nd returns 0 when an entry of the indices parameter is negative (note that, per the TensorFlow docs, out-of-range indices zero-fill on GPU, while on CPU they may instead raise an error, so this behavior can be device-dependent).
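
A quick check of that behavior (a minimal sketch; whether out-of-range indices yield zeros or an error depends on the device, as noted above):

import tensorflow as tf
tf.enable_eager_execution()

params = tf.constant([[1.0, 2.0],
                      [3.0, 4.0]])
# [0, 1] is in range; [-1, -1] is out of range and gathers 0.0
# on devices where gather_nd zero-fills invalid indices.
print(tf.gather_nd(params=params, indices=[[0, 1], [-1, -1]]))
# -> [2. 0.] (where invalid indices zero-fill)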

First, get the index pairs of the target id (103) in input_ids.

import tensorflow as tf
tf.enable_eager_execution()

input_ids = tf.constant([[  101,  1996, 16360,  103,  1010,  1996,  4223,  1997],
                         [  101,  103,  3793,  103,  2443,  2000,  103,  2469]])
embeddings = tf.random_normal((2,8,3))

condition = tf.equal(input_ids, 103)
indices_value = tf.where(condition)
# [[0 3]
#  [1 1]
#  [1 3]
#  [1 6]]

Then we get the number of matching tokens in each sequence and a boolean mask over the padded index positions.

length = tf.reduce_sum(tf.cast(condition,tf.int32),axis=-1)
# [1 3]
indices_mask = tf.sequence_mask(length,tf.reduce_max(length))
# [[ True False False]
#  [ True  True  True]]

Next, we place each index pair at its per-sequence position, using -1 as the padding value. The +1/-1 shift works because tf.scatter_nd fills unwritten positions with 0: shifting the indices up by one before scattering and subtracting one afterwards turns that zero fill into -1.

# Scatter the shifted index pairs into a [B x N x 2] tensor; positions
# not written by scatter_nd stay 0 and become -1 after the subtraction.
result_indices = tf.scatter_nd(tf.where(indices_mask),
                               indices_value + 1,
                               (indices_mask.shape[0], indices_mask.shape[1], tf.rank(input_ids))) - 1
# [[[ 0  3]
#   [-1 -1]
#   [-1 -1]]
#
#  [[ 1  1]
#   [ 1  3]
#   [ 1  6]]]

Finally, we get the result with tf.gather_nd; the rows of -1 indices gather zero vectors.

result = tf.gather_nd(indices=result_indices, params=embeddings)
print(result)
# [[[ 1.22885     0.77642244 -0.82193506]
#   [ 0.          0.          0.        ]
#   [ 0.          0.          0.        ]]
# 
#  [[-0.0567691   0.07378497 -0.4799046 ]
#   [-1.1627238  -1.994217    0.8443906 ]
#   [ 0.776338   -0.25828102 -1.7915782 ]]]
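
If relying on the out-of-range behavior of tf.gather_nd feels fragile, here is an alternative sketch (using the tensors defined above) that scatters the gathered embeddings directly, so no invalid indices are ever used; gathered, scatter_idx, shape, and result_alt are names introduced here for illustration:

# Gather the [M x D] embeddings of the matched tokens, then scatter
# them straight into the zero-padded [B x N x D] result.
gathered = tf.gather_nd(params=embeddings, indices=indices_value)
scatter_idx = tf.cast(tf.where(indices_mask), tf.int32)
shape = tf.stack([tf.shape(embeddings)[0],
                  tf.reduce_max(length),
                  tf.shape(embeddings)[2]])
result_alt = tf.scatter_nd(scatter_idx, gathered, shape)
# result_alt matches the result above.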

Upvotes: 1
