Fjolfrin

Reputation: 33

How to handle hidden-cell output of 2-layer LSTM in PyTorch?

I have made a network with a LSTM and a fully connected layer in PyTorch. I want to test how an increase in the LSTM layers affects my performance.

Say my input is (6, 9, 14), meaning batch size 6, sequence length 9, and feature size 14, and I'm working on a task with 6 classes, so I expect a 6-element prediction for each sequence. The output of this network after the FC layer should therefore be (6, 6); however, with 2 LSTM layers it becomes (12, 6).

I don't understand how I should handle the output of the LSTM layer to bring the batch dimension back from [2 * batch_size] to [batch_size]. Also, I know I'm using the hidden state (rather than the output sequence) as the input to the FC layer; I want to try it this way for now.

Should I sum or concatenate every two batches, or something else? Cheers!

    def forward(self, x):
        hidden_0 = torch.zeros((self.lstm_layers, x.size(0), self.hidden_size), dtype=torch.double, device=self.device)
        cell_0 = torch.zeros((self.lstm_layers, x.size(0), self.hidden_size), dtype=torch.double, device=self.device)

        y1, (hidden_1, cell_1) = self.lstm(x, (hidden_0, cell_0))
        hidden_1 = hidden_1.view(-1, self.hidden_size)

        y = self.linear(hidden_1)

        return y
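To reproduce what I'm seeing, here is a minimal standalone sketch (the hidden size of 32 is made up, the rest matches my numbers above):

```python
import torch
import torch.nn as nn

batch_size, seq_len, input_size, hidden_size = 6, 9, 14, 32

# a 2-layer LSTM stacks one (batch, hidden) slice per layer in the hidden state
lstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
x = torch.randn(batch_size, seq_len, input_size)
_, (hidden, _) = lstm(x)

print(hidden.shape)                        # torch.Size([2, 6, 32])
print(hidden.view(-1, hidden_size).shape)  # torch.Size([12, 32]) -> 12 "batches" after the FC layer
```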

Upvotes: 0

Views: 2083

Answers (2)

Robert Deibel

Reputation: 26

The hidden state of a multi-layer LSTM has shape (num_layers, batch_size, hidden_size); see the nn.LSTM output documentation. It stacks the final hidden state of each layer along dimension 0.

In your example you convert the shape into two dimensions here:

    hidden_1 = hidden_1.view(-1, self.hidden_size)

This flattens the layer and batch dimensions together, giving shape (num_layers * batch_size, hidden_size), which is where the doubled batch comes from.

What you would want to do is only use the hidden state of the last layer:

    hidden = hidden_1[-1, :, :]  # last layer only: (layers, bs, hidden) -> (bs, hidden)
    y = self.linear(hidden)
    return y
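A quick shape check outside the module (hidden size chosen arbitrarily, 6 classes as in the question):

```python
import torch
import torch.nn as nn

batch_size, seq_len, input_size, hidden_size = 6, 9, 14, 32

lstm = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)
linear = nn.Linear(hidden_size, 6)  # 6 classes

x = torch.randn(batch_size, seq_len, input_size)
_, (hidden, _) = lstm(x)   # hidden: (2, 6, 32)

y = linear(hidden[-1])     # last layer only: (6, 32) -> (6, 6)
print(y.shape)             # torch.Size([6, 6])
```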

Upvotes: 1

ki-ljl

Reputation: 509

For a multi-layer LSTM, you can write it like this:

    class LSTM(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
            super().__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.output_size = output_size
            self.batch_size = batch_size
            self.num_directions = 1  # unidirectional LSTM
            self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
            self.linear = nn.Linear(self.hidden_size, self.output_size)

        def forward(self, input_seq):
            h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
            c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
            seq_len = input_seq.shape[1]
            # input_seq: (batch_size, seq_len, input_size)
            # output: (batch_size, seq_len, num_directions * hidden_size)
            output, _ = self.lstm(input_seq, (h_0, c_0))
            # run the FC layer over every time step at once
            output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size)
            pred = self.linear(output)
            # restore the sequence dimension and keep only the last time step
            pred = pred.view(self.batch_size, seq_len, -1)
            pred = pred[:, -1, :]
            return pred
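Note that for a unidirectional LSTM, the last time step of output is exactly the top layer's final hidden state, so taking pred from output[:, -1, :] uses the same features as indexing hidden with [-1]. A quick sketch (sizes are arbitrary):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=14, hidden_size=32, num_layers=2, batch_first=True)
x = torch.randn(6, 9, 14)
output, (h_n, _) = lstm(x)

# the last time step of `output` is the top layer's final hidden state,
# so both routes feed identical features to the FC layer
assert torch.allclose(output[:, -1, :], h_n[-1])
```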

Upvotes: 0
