gdelab

Extracting specific elements from a tensor in tensorflow

I'm using TensorFlow in Python. I have a data tensor of shape [?, 5, 37] and an idx tensor of shape [?, 5].

I'd like to extract elements from data and get an output of shape [?, 5] such that:

output[i][j] = data[i][j][idx[i, j]] for all i in range(?) and j in range(5)
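Below is a plain NumPy version of this indexing rule. It can serve as a reference when checking any TensorFlow implementation. It is written in NumPy rather than TensorFlow and is only a checkable sketch. The helper names extract_reference and extract_vectorized are hypothetical, and the shapes in the usage lines are made up.

```python
import numpy as np

def extract_reference(data, idx):
    """Loop-based reference: out[i, j] = data[i, j, idx[i, j]]."""
    n_rows, n_cols = idx.shape
    out = np.empty((n_rows, n_cols), dtype=data.dtype)
    for i in range(n_rows):
        for j in range(n_cols):
            out[i, j] = data[i, j, idx[i, j]]
    return out

def extract_vectorized(data, idx):
    """Same result with np.take_along_axis on the last axis."""
    # idx[..., None] has shape [N, 5, 1], so it has the same rank as data
    return np.take_along_axis(data, idx[..., None], axis=-1)[..., 0]

# Usage with a made-up batch of 3
rng = np.random.default_rng(0)
data = rng.standard_normal((3, 5, 37))
idx = rng.integers(0, 37, size=(3, 5))
assert np.array_equal(extract_reference(data, idx), extract_vectorized(data, idx))
```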

It looks like the tf.gather_nd() function is the closest to my needs, but I don't see how to use it in my case...
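In tf.gather_nd, the last axis of indices holds a full coordinate into data. A length-3 vector [i, j, k] selects the scalar data[i, j, k], and the remaining axes of indices determine the output shape. The small sketch below assumes TensorFlow 2.x with eager execution; the toy tensors are illustrative only.

```python
import tensorflow as tf

# A tiny [2, 2, 3] tensor: data[i, j, k] == 6*i + 3*j + k
data = tf.reshape(tf.range(12), [2, 2, 3])

# Each innermost row [i, j, k] of indices picks the scalar data[i, j, k]
indices = tf.constant([[[0, 0, 0], [0, 1, 2]],
                       [[1, 0, 1], [1, 1, 0]]])  # shape [2, 2, 3]
picked = tf.gather_nd(data, indices)  # shape [2, 2]
```

So gather_nd needs an index tensor of shape [?, 5, 3] whose entry at [i, j] is [i, j, idx[i, j]].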

Thanks!

EDIT: I managed to do it with gather_nd as shown below, but is there a better option? It seems a bit heavy-handed.

    import tensorflow as tf

    nRows = tf.shape(idx)[0]  # ==> ? (dynamic batch size)
    nCols = tf.shape(idx)[1]  # ==> 5
    # m1[i, j] == j  (column index of every cell)
    m1 = tf.reshape(tf.tile(tf.range(nCols), [nRows]), shape=[nRows, nCols])
    # m2[i, j] == i  (row index of every cell)
    m2 = tf.transpose(tf.reshape(tf.tile(tf.range(nRows), [nCols]),
                                 shape=[nCols, nRows]))
    # tf.pack was renamed tf.stack in TensorFlow 1.0; idx must be int32 to stack
    indices = tf.stack([m2, m1, idx], axis=-1)
    # indices has shape [?, 5, 3] with indices[i, j] == [i, j, idx[i, j]]
    output = tf.gather_nd(data, indices=indices)
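Newer TensorFlow releases (assumed here: 2.x) let both tf.gather and tf.gather_nd take a batch_dims argument. This argument treats leading axes as batch axes, so the row and column index tensors no longer have to be built by hand. The minimal sketch below uses hypothetical wrapper names.

```python
import tensorflow as tf

def extract_with_gather(data, idx):
    # Axes 0 and 1 are batch axes, so idx[i, j] indexes the length-37
    # vector data[i, j]; output shape is [?, 5]
    return tf.gather(data, idx, batch_dims=2)

def extract_with_gather_nd(data, idx):
    # Same idea with gather_nd: each (i, j) gets a length-1 index [idx[i, j]]
    return tf.gather_nd(data, tf.expand_dims(idx, -1), batch_dims=2)

# Usage on a toy [2, 2, 3] tensor
data = tf.reshape(tf.range(12), [2, 2, 3])
idx = tf.constant([[0, 2], [1, 0]])
out_a = extract_with_gather(data, idx)
out_b = extract_with_gather_nd(data, idx)
```

With batch_dims=2 and no explicit axis, tf.gather indexes axis 2 separately for each (i, j) pair. That is exactly output[i][j] = data[i][j][idx[i, j]].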

Answers (1)

gdelab

I managed to do it with gather_nd by building an explicit [?, 5, 3] index tensor, as shown below.

    import tensorflow as tf

    nRows = tf.shape(idx)[0]  # ==> ? (dynamic batch size)
    nCols = tf.shape(idx)[1]  # ==> 5
    # m1[i, j] == j  (column index of every cell)
    m1 = tf.reshape(tf.tile(tf.range(nCols), [nRows]), shape=[nRows, nCols])
    # m2[i, j] == i  (row index of every cell)
    m2 = tf.transpose(tf.reshape(tf.tile(tf.range(nRows), [nCols]),
                                 shape=[nCols, nRows]))
    # tf.pack was renamed tf.stack in TensorFlow 1.0; idx must be int32 to stack
    indices = tf.stack([m2, m1, idx], axis=-1)
    # indices has shape [?, 5, 3] with indices[i, j] == [i, j, idx[i, j]]
    output = tf.gather_nd(data, indices=indices)
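The tile/transpose construction of the [?, 5, 3] index tensor can be written more compactly with tf.meshgrid. Alternatively, a one-hot mask avoids gather_nd altogether. Both functions below are sketches that assume TensorFlow 2.x eager execution, and their names are hypothetical.

```python
import tensorflow as tf

def extract_with_meshgrid(data, idx):
    # Same indices as the tile/transpose code, but tf.meshgrid builds the grids
    idx = tf.cast(idx, tf.int32)
    n_rows = tf.shape(idx)[0]
    n_cols = tf.shape(idx)[1]
    # With indexing="ij": rows[i, j] == i and cols[i, j] == j
    rows, cols = tf.meshgrid(tf.range(n_rows), tf.range(n_cols), indexing="ij")
    indices = tf.stack([rows, cols, idx], axis=-1)  # indices[i, j] == [i, j, idx[i, j]]
    return tf.gather_nd(data, indices)

def extract_with_one_hot(data, idx):
    # Zero out every position except idx[i, j] on the last axis, then sum it away
    mask = tf.one_hot(idx, depth=tf.shape(data)[-1], dtype=data.dtype)
    return tf.reduce_sum(data * mask, axis=-1)

# Usage on a toy float tensor of shape [2, 2, 3]
data = tf.reshape(tf.range(12, dtype=tf.float32), [2, 2, 3])
idx = tf.constant([[0, 2], [1, 0]])
out_mesh = extract_with_meshgrid(data, idx)
out_hot = extract_with_one_hot(data, idx)
```

The one-hot version does work proportional to the last axis (37) for every output element. It also propagates NaN whenever an unselected entry is inf or NaN, because 0 * inf is NaN. The gather-based version only reads the selected entries.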
