Reputation: 11200
I have an input tensor input_ids
with shape [B x T]
and a corresponding embedding tensor with shape [B x T x D]
(B: batch size, T: sequence length, D: embedding dimension)
. The input ids are vocabulary ids, and the embedding tensor contains their corresponding embeddings.
From the embedding tensor I want to select the elements with a certain id (e.g., 103
). It would be easy to do this using tf.where
and tf.gather_nd
, but what I don't know how to do is organize the results into a batch of shape [B x N x D]
, where N
is the maximum number of tokens with that id (103
) in any sequence, padding with zero vectors as needed.
Code might show it better (let's say B=2, T=8, and D=3
):
import tensorflow as tf
tf.enable_eager_execution()
input_ids = tf.constant([[ 101, 1996, 16360, 103, 1010, 1996, 4223, 1997],
[ 101, 103, 3793, 103, 2443, 2000, 103, 2469]])
embeddings = tf.random_normal((2,8,3))
# input_ids holds two sequences: the first contains one 103 token, the second contains three.
I want to select from embeddings
those vectors that correspond to input_ids==103
and pad the remaining slots with zeros.
I can get the flat selection with:
indices = tf.where(tf.equal(input_ids, 103))
result = tf.gather_nd(indices=indices, params=embeddings)
# result.shape == [4, 3]
# This is a [4 x 3] matrix where 4 is the total number of 103
# tokens in the batch and 3 is the embedding dimension.
# Now I want to organize this into a batch with the same
# batch size as the input, i.e., desired shape = (2 x 3 x 3),
# where the second (3 x 3) slice contains all three token 103
# embeddings from the second sequence, but the first (3 x 3)
# slice holds only the single token 103 embedding from the
# first sequence; the rest is padded with zeros.
In general, this will result in an [M x D]
tensor (M = total number of 103 tokens in the batch). What I want is [B x N x D]
, where N is the maximum number of 103 tokens in any single sequence (3 in the case above). I hope the description is clear (it's kind of hard to explain the exact problem).
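For concreteness, the desired output for this example could be built by hand as follows (hard-coding the known 103 positions, 3 in the first sequence and 1/3/6 in the second; this is only to illustrate the target, not a general solution):
desired = tf.stack([
    tf.concat([embeddings[0:1, 3], tf.zeros((2, 3))], axis=0),  # seq 0: one match, two zero rows of padding
    tf.gather(embeddings[1], [1, 3, 6]),                        # seq 1: three matches, no padding needed
])
# desired.shape == (2, 3, 3)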
How can I achieve this?
Upvotes: 0
Views: 599
Reputation: 6176
I think we can take advantage of the property that tf.gather_nd
returns 0
when the indices
parameter contains negative (out-of-bound) values.
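As a quick sanity check of that property (note: the TensorFlow docs only guarantee this zero-fill for out-of-bound indices on GPU; on CPU an error may be raised instead, so treat this as an observation for the eager setup used here):
import tensorflow as tf
tf.enable_eager_execution()

params = tf.constant([[1., 2.], [3., 4.]])
# Row index -1 is out of bounds, so its output row comes back as zeros.
print(tf.gather_nd(params=params, indices=tf.constant([[0], [-1]])))
# [[1. 2.]
#  [0. 0.]]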
First, get the indices of the elements with the certain id in input_ids
.
import tensorflow as tf
tf.enable_eager_execution()
input_ids = tf.constant([[ 101, 1996, 16360, 103, 1010, 1996, 4223, 1997],
[ 101, 103, 3793, 103, 2443, 2000, 103, 2469]])
embeddings = tf.random_normal((2,8,3))
condition = tf.equal(input_ids, 103)
indices_value = tf.where(condition)
# [[0 3]
# [1 1]
# [1 3]
# [1 6]]
Then we get the number of matching tokens in every sequence and a mask over the index positions.
length = tf.reduce_sum(tf.cast(condition,tf.int32),axis=-1)
# [1 3]
indices_mask = tf.sequence_mask(length,tf.reduce_max(length))
# [[ True False False]
# [ True True True]]
Next, we need to place each index value at its position within its own sequence. Note that tf.scatter_nd fills every unspecified slot with 0, which would collide with the valid index 0, so we shift the values by +1 before scattering and subtract 1 afterwards; every empty slot then becomes -1.
result_indices = tf.scatter_nd(
    tf.where(indices_mask),
    indices_value + 1,
    (indices_mask.shape[0], indices_mask.shape[1], tf.rank(input_ids))) - 1
# [[[ 0 3]
# [-1 -1]
# [-1 -1]]
#
# [[ 1 1]
# [ 1 3]
# [ 1 6]]]
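To see why the +1/-1 shift is needed, here is a minimal illustration (toy values, separate from the example above) of how tf.scatter_nd fills unwritten slots:
print(tf.scatter_nd(indices=tf.constant([[0]]),
                    updates=tf.constant([[5]]),
                    shape=(3, 1)))
# [[5]
#  [0]   <- unwritten slots default to 0, indistinguishable from a real index 0
#  [0]]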
Finally, we get the result with tf.gather_nd
; the -1 indices gather zero vectors.
result = tf.gather_nd(indices=result_indices, params=embeddings)
print(result)
# [[[ 1.22885 0.77642244 -0.82193506]
# [ 0. 0. 0. ]
# [ 0. 0. 0. ]]
#
# [[-0.0567691 0.07378497 -0.4799046 ]
# [-1.1627238 -1.994217 0.8443906 ]
# [ 0.776338 -0.25828102 -1.7915782 ]]]
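For reuse, the steps above can be wrapped into a helper; gather_token_embeddings and its signature are names I'm introducing here (not from the original post), under the same TF1/eager assumptions:
def gather_token_embeddings(input_ids, embeddings, token_id):
    # Select embeddings where input_ids == token_id, zero-padded to [B, N, D].
    condition = tf.equal(input_ids, token_id)
    indices_value = tf.where(condition)
    length = tf.reduce_sum(tf.cast(condition, tf.int32), axis=-1)
    indices_mask = tf.sequence_mask(length, tf.reduce_max(length))
    result_indices = tf.scatter_nd(
        tf.where(indices_mask),
        indices_value + 1,
        (indices_mask.shape[0], indices_mask.shape[1], tf.rank(input_ids))) - 1
    return tf.gather_nd(indices=result_indices, params=embeddings)

print(gather_token_embeddings(input_ids, embeddings, 103))  # shape (2, 3, 3)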
Upvotes: 1