John M.

Reputation: 875

How to use pack_padded_sequence with multiple variable-length input with the same label in pytorch

I have a model which takes three variable-length inputs with the same label. Is there a way I could use pack_padded_sequence somehow? If so, how should I sort my sequences?

For example,

a = (([0,1,2], [3,4], [5,6,7,8]), 1) # three sequences of lengths 3, 2, 4; the label is 1
b = (([0,1], [2], [6,7,8,9,10]), 1)

Both a and b will be fed into three separate LSTMs, and the results will be merged to predict the target.

Upvotes: 5

Views: 10376

Answers (2)

Victor Zuanazzi

Reputation: 1974

The answer above is already quite informative, but I often have trouble following PyTorch's documentation, so I wrote these two helper functions to deal with the pack_padded_sequence / pad_packed_sequence workflow.

import torch

def batch_to_sequence(x, len_x, batch_first):
    """Helper for the pack_padded_sequence step.
    Returns a packed sequence ready to be fed to an RNN.
    The data does NOT have to be sorted by sentence length, we do that for you!
    Input:
        x: (torch.tensor[max_len, batch, embedding_dim]) tensor containing the
            padded data, i.e. the word embeddings in the order they appear in
            the sentence. If batch_first == True, the max_len and batch
            dimensions are transposed.
        len_x: (torch.tensor[batch]) tensor containing the length of each
            sentence in x.
        batch_first: (bool) indicates whether the batch or the sentence length
            is indexed in the first dimension of the tensor.
    Output:
        x: (PackedSequence) the packed sequence containing the data.
        idx: (torch.tensor[batch]) the indices used to sort x; they are needed
            in sequence_to_batch to restore the original order.
        len_x: (torch.tensor[batch]) the sorted lengths, also needed by
            sequence_to_batch."""

    #sort the batch by sentence length, as pack_padded_sequence expects
    len_x, idx = len_x.sort(0, descending=True)
    x = x[idx] if batch_first else x[:, idx]

    #remove the paddings before feeding the batch to the LSTM
    x = torch.nn.utils.rnn.pack_padded_sequence(x,
                                                len_x,
                                                batch_first = batch_first)

    return x, len_x, idx

and

def sequence_to_batch(x, len_x, idx, output_size, batch_first, all_hidden = False):
    """Helper for the pad_packed_sequence step.
    Input:
        x: (PackedSequence) the output of the LSTM or of pack_padded_sequence().
        len_x: (torch.tensor[batch]) the sorted lengths returned by
            batch_to_sequence().
        idx: (torch.tensor[batch]) the indices used to sort len_x.
        output_size: (int) the expected dimension of the output embeddings.
        batch_first: (bool) indicates whether the batch or the sentence length
            is indexed in the first dimension of the tensor.
        all_hidden: (bool) if False, returns only the last relevant hidden
            state of each sentence - the hidden states produced by the padding
            are ignored. If True, returns all hidden states.
    Output:
        x: (torch.tensor[batch, output_size]) the last relevant hidden state of
            each sentence, or the full padded tensor if all_hidden == True.
    """

    #re-introduce the paddings
    #doc pad_packed_sequence:
    #https://pytorch.org/docs/master/nn.html#torch.nn.utils.rnn.pad_packed_sequence
    x, _ = torch.nn.utils.rnn.pad_packed_sequence(x,
                                                  batch_first = batch_first)
    if all_hidden:
        return x

    #make the batch dimension come first so the flat indexing below works
    if not batch_first:
        x = x.transpose(0, 1)

    #index of the last real token of each sentence (ignoring the padding),
    #i.e. position len_x[i] - 1 in row i of the flattened tensor
    longest_sentence = int(max(len_x))
    last_word = [i * longest_sentence + int(len_x[i]) - 1 for i in range(len(len_x))]

    #get the relevant hidden states
    x = x.contiguous().view(-1, output_size)
    x = x[last_word, :]

    #unsort the batch!
    _, idx = idx.sort(0, descending=False)
    x = x[idx, :]

    return x

You can use them inside the forward pass of your LSTM:

def forward(self, x, len_x):

    #convert the padded batch into a packed sequence
    x, len_x, idx = batch_to_sequence(x, len_x, self.batch_first)

    #run the LSTM
    x, (_, _) = self.uni_lstm(x)

    #take the packed sequence and give back the embedding vectors
    x = sequence_to_batch(x, len_x, idx, self.output_size, self.batch_first)

    return x
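
For completeness, here is a minimal sketch of how these helpers could be wired into a module and called on a padded batch. The class name SentenceEncoder and the concrete sizes are placeholders of mine, not something the helpers require:

import torch
import torch.nn as nn

class SentenceEncoder(nn.Module):
    """Hypothetical wrapper around the two helpers above (illustration only)."""

    def __init__(self, embedding_dim=8, hidden_size=16, batch_first=True):
        super(SentenceEncoder, self).__init__()
        self.batch_first = batch_first
        self.output_size = hidden_size
        self.uni_lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=batch_first)

    def forward(self, x, len_x):
        #pad -> pack -> LSTM -> unpack -> last relevant hidden state per sentence
        x, len_x, idx = batch_to_sequence(x, len_x, self.batch_first)
        x, (_, _) = self.uni_lstm(x)
        return sequence_to_batch(x, len_x, idx, self.output_size, self.batch_first)

#toy batch: 3 sentences of lengths 3, 2, 4, padded to length 4,
#with 8-dimensional "embeddings" for each token
batch = torch.zeros(3, 4, 8)
lengths = torch.tensor([3, 2, 4])

encoder = SentenceEncoder()
out = encoder(batch, lengths)
print(out.shape)  # torch.Size([3, 16]): one hidden state per sentence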

Upvotes: 0

Wasi Ahmad

Reputation: 37701

Let's do it step by step.

Input Data Processing

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

a = (([0,1,2], [3,4], [5,6,7,8]), 1)

# store the length of each sequence in an array
len_a = np.array([len(seq) for seq in a[0]])
# build a zero-padded matrix of shape (num_sequences, max_length)
variable_a = np.zeros((len(len_a), np.amax(len_a)))
for i, seq in enumerate(a[0]):
    variable_a[i, 0:len(seq)] = seq

vocab_size = len(np.unique(variable_a))
variable_a = Variable(torch.from_numpy(variable_a).long())
print(variable_a)

It prints:

Variable containing:
 0  1  2  0
 3  4  0  0
 5  6  7  8
[torch.LongTensor of size 3x4]
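
The same padding applies to the second tuple b from the question. Here is a small sketch with the logic wrapped in a reusable helper; the name pad_sequences is just a placeholder of mine:

def pad_sequences(seqs):
    """Hypothetical helper: zero-pad a tuple of variable-length sequences
    into a (num_sequences, max_length) LongTensor plus a length array."""
    lengths = np.array([len(seq) for seq in seqs])
    padded = np.zeros((len(lengths), np.amax(lengths)))
    for i, seq in enumerate(seqs):
        padded[i, 0:len(seq)] = seq
    return Variable(torch.from_numpy(padded).long()), lengths

b = (([0,1], [2], [6,7,8,9,10]), 1)
variable_b, len_b = pad_sequences(b[0])
print(variable_b)  # a 3x5 LongTensor, rows padded with zeros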

Defining embedding and RNN layer

Now, let's say, we have an Embedding and RNN layer class as follows.

class EmbeddingLayer(nn.Module):

    def __init__(self, input_size, emsize):
        super(EmbeddingLayer, self).__init__()
        self.embedding = nn.Embedding(input_size, emsize)

    def forward(self, input_variable):
        return self.embedding(input_variable)


class Encoder(nn.Module):

    def __init__(self, input_size, hidden_size, bidirection):
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirection = bidirection
        self.rnn = nn.LSTM(self.input_size, self.hidden_size, batch_first=True,
                           bidirectional=self.bidirection)

    def forward(self, sent_variable, sent_len):
        # Sort by length (keep idx)
        sent_len, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
        idx_unsort = np.argsort(idx_sort)

        idx_sort = torch.from_numpy(idx_sort)
        sent_variable = sent_variable.index_select(0, Variable(idx_sort))

        # Handling padding in Recurrent Networks
        sent_packed = nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True)
        sent_output = self.rnn(sent_packed)[0]
        sent_output = nn.utils.rnn.pad_packed_sequence(sent_output, batch_first=True)[0]

        # Un-sort by length
        idx_unsort = torch.from_numpy(idx_unsort)
        sent_output = sent_output.index_select(0, Variable(idx_unsort))

        return sent_output

Embed and encode the processed input data

We can embed and encode our input as follows.

emb = EmbeddingLayer(vocab_size, 50)
enc = Encoder(50, 100, False)

emb_a = emb(variable_a)
enc_a = enc(emb_a, len_a)

If you print the size of enc_a, you will get torch.Size([3, 4, 100]): 3 sequences, padded to a maximum length of 4, each time step encoded into a 100-dimensional hidden state.

Please note that the above code runs only on the CPU.
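
To tie this back to the question (three variable-length inputs sharing one label), here is a rough sketch of one way to pool and merge the encoded sequences. Taking the last relevant hidden state of each sequence, concatenating them, and using a 2-class linear classifier are choices of mine, not something fixed by the above; I also reuse the single encoder applied to the batched sequences, whereas with three separate LSTMs you would run each sequence through its own Encoder and concatenate the results the same way:

# last relevant hidden state of each of the 3 sequences
# (enc_a has shape [3, max_len, 100]; positions beyond len_a[i] are padding)
last_states = torch.stack([enc_a[i, len_a[i] - 1, :] for i in range(3)])  # [3, 100]

# merge the three sequence representations and predict the shared label
merged = last_states.view(1, -1)    # [1, 300]: three encodings side by side
classifier = nn.Linear(3 * 100, 2)  # assume 2 output classes for the label
logits = classifier(merged)
print(logits.size())                # torch.Size([1, 2])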

Upvotes: 9
