Jindřich

PyTorch equivalent of `tf.reverse_sequence`?

I would like to run a backward-direction LSTM on a padded batch, which requires reversing each input sequence without moving the padding.

For a batch like this (where _ stands for padding):

a b c _ _ _
d e f g _ _
h i j k l m

I would like to get:

c b a _ _ _
g f e d _ _
m l k j i h

TensorFlow has a function, `tf.reverse_sequence`, that takes the input tensor and the lengths of the sequences in the batch and returns the reversed batch. Is there an easy way of doing this in PyTorch?

Upvotes: 3

Views: 2397

Answers (1)

dennlinger

Unfortunately, there is no direct equivalent yet, although it has been requested.

I also looked into `PackedSequence`, but it has no `.flip()` operation defined on it. Assuming you already have the sequence lengths, as in your TensorFlow example, you could implement it with this function:

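import torch

def flipBatch(data, lengths):
    # Reverse the first lengths[i] steps of each row; padding stays in place.
    assert data.shape[0] == len(lengths), "Dimension Mismatch!"
    for i in range(data.shape[0]):
        # flip() returns a copy, so assigning it back into the same slice is safe
        data[i, :lengths[i]] = data[i, :lengths[i]].flip(dims=[0])

    return data  # note: data is modified in place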

Note that this expects a batch-first tensor (batch_size x sequence_length, optionally followed by feature dimensions), so for the sequence-first layout that `nn.LSTM` uses by default you would have to transpose first or adapt the indexing. This more or less covers the proposal in the feature request, updated to the current PyTorch API.
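If the Python loop becomes a bottleneck, a loop-free variant can be sketched with `torch.gather`: for each row, build an index that maps position t to lengths[i] - 1 - t inside the valid prefix and leaves padding positions pointing at themselves. This is only a sketch, not a PyTorch API; `reverse_padded` is a hypothetical helper name, and it assumes a batch-first tensor and an integer `lengths` tensor. Unlike `flipBatch`, it returns a new tensor instead of writing into the input. The last lines show the backward-LSTM use case from the question; the layer sizes there are arbitrary.

```python
import torch


def reverse_padded(data: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Reverse the first lengths[i] steps of each row of a batch-first tensor.

    Assumes data has shape (batch, seq_len, *features) and lengths has shape
    (batch,). Padding positions stay where they are. Returns a new tensor.
    """
    batch, seq_len = data.shape[0], data.shape[1]
    lengths = lengths.to(device=data.device, dtype=torch.long).unsqueeze(1)  # (B, 1)
    steps = torch.arange(seq_len, device=data.device).expand(batch, seq_len)  # (B, T)
    # Valid position t reads from lengths - 1 - t; padding positions read from themselves.
    idx = torch.where(steps < lengths, lengths - 1 - steps, steps)
    # Broadcast the (B, T) index over any trailing feature dimensions for gather().
    idx = idx.reshape(batch, seq_len, *([1] * (data.dim() - 2))).expand_as(data)
    return data.gather(1, idx)


# The example batch from the question, with 0 standing in for padding.
data = torch.tensor([[1, 2, 3, 0, 0, 0],
                     [4, 5, 6, 7, 0, 0],
                     [8, 9, 10, 11, 12, 13]])
lengths = torch.tensor([3, 4, 6])
print(reverse_padded(data, lengths).tolist())
# [[3, 2, 1, 0, 0, 0], [7, 6, 5, 4, 0, 0], [13, 12, 11, 10, 9, 8]]

# Backward-direction LSTM: reverse the inputs, run an ordinary LSTM,
# then reverse the outputs so step t lines up with input step t again.
lstm = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
x = torch.randn(3, 6, 4)
out, _ = lstm(reverse_padded(x, lengths))
backward_out = reverse_padded(out, lengths)  # shape (3, 6, 8)
```

After reversal the valid steps come first, so the LSTM only reaches the padding after every real step, and the re-aligned outputs at valid positions are unaffected by it. The final hidden state `h_n`, however, has also run over the padding.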

Upvotes: 5