N. Virgo

Reputation: 8439

What's the difference between "hidden" and "output" in PyTorch LSTM?

I'm having trouble understanding the documentation for PyTorch's LSTM module (and also RNN and GRU, which are similar). Regarding the outputs, it says:

Outputs: output, (h_n, c_n)

  • output (seq_len, batch, hidden_size * num_directions): tensor containing the output features (h_t) from the last layer of the RNN, for each t. If a torch.nn.utils.rnn.PackedSequence has been given as the input, the output will also be a packed sequence.
  • h_n (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t=seq_len
  • c_n (num_layers * num_directions, batch, hidden_size): tensor containing the cell state for t=seq_len

It seems that the variables output and h_n both give the values of the hidden state. Does h_n just redundantly provide the last time step that's already included in output, or is there something more to it than that?
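
To make this concrete, here is a minimal example of the call I'm asking about (the sizes are arbitrary):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(5, 3, 10)   # (seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([5, 3, 20]) -> (seq_len, batch, hidden_size * num_directions)
print(h_n.shape)     # torch.Size([2, 3, 20]) -> (num_layers * num_directions, batch, hidden_size)
print(c_n.shape)     # torch.Size([2, 3, 20]) -> (num_layers * num_directions, batch, hidden_size)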

Upvotes: 117

Views: 56613

Answers (6)

Nor.Z

Reputation: 1349

This post is about: the relation between hidden_size & output_size, h_t vs out, and visualizing the actual values in an RNN.

(disclaimer: this is just my personal understanding; I could be wrong.)


  • Even though, in the math, the output size of y could be customized via the shape of y's weight matrix, PyTorch's RNN has no separate output projection: it treats the output y as equal to the hidden state h, so the output size is always hidden_size.

  • @note: [h_t vs out]

    Though I said, >"Pytroch decided to treat the output y == hidden state h".
    The actual output has slight difference.

    • h_t : the final hidden state after processing all time steps.

    • out : the output at each time step.

    • You can check that the overlapping values are indeed the same, as the snippet below does (and as several other answers here have done).

              out, h_t = self.rnn(x, h0)
              for i in range(0, h_t.shape[1]):
                  # for every batch element, the last time step of out equals the last layer of h_t
                  if not torch.equal(out[i, -1], h_t[-1, i]):
                      raise ValueError("out and h_t do not match")
  • Here is one real output from out, h_t = self.rnn(x, h0), printed with print(out) and print(h_t).

    • Watch this row: [ 0.118, 0.049, 0.117, -0.012, 0.138, ..., -0.074, -0.069, 0.046, -0.004, 0.098], // << eg see this line, they are same
      Each row tensor of h_t can be found as the last row tensor of the corresponding block in out.

    • output & hidden_state shapes

      • There are no corresponding rows in `output` for the layer-0 rows of `hidden_state` (only the last layer appears in `output`).

        hidden_state shape: (num_layers, batch_size, hidden_size)
        If `batch_size` is just 1, each layer contributes just one row vector.

        output shape: (batch_size, sequence_length, hidden_size)

  • >--   (this marker is printed by the print(">--") call in the forward pass below)
    tensor([[[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.132,  0.046,  0.054, -0.046,  0.109,  ..., -0.043, -0.005,  0.098, -0.038,  0.072],
            ...,
            [ 0.093,  0.082,  0.231, -0.073,  0.143,  ..., -0.067, -0.138,  0.016,  0.007,  0.150],
            [ 0.098,  0.067,  0.218, -0.055,  0.120,  ..., -0.082, -0.096,  0.034, -0.005,  0.149],
            [ 0.158,  0.026,  0.242, -0.053,  0.095,  ..., -0.100, -0.105,  0.032, -0.022,  0.121],
            [ 0.184,  0.051,  0.244, -0.048,  0.115,  ..., -0.125, -0.123,  0.023,  0.007,  0.129],
            [ 0.147,  0.050,  0.239, -0.025,  0.126,  ..., -0.100, -0.135,  0.031, -0.027,  0.131]],
    
            [[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.122,  0.021,  0.088, -0.125,  0.089,  ..., -0.045, -0.003,  0.064,  0.024,  0.040],
            ...,
            [ 0.071, -0.062,  0.179, -0.167, -0.020,  ..., -0.195, -0.090,  0.108,  0.166,  0.113],
            [ 0.137, -0.026,  0.259, -0.002,  0.045,  ..., -0.156, -0.108,  0.046,  0.049,  0.138],
            [ 0.122, -0.019,  0.159,  0.013,  0.085,  ..., -0.098, -0.046,  0.006, -0.002,  0.072],
            [ 0.133,  0.047,  0.095,  0.011,  0.123,  ..., -0.081, -0.055,  0.033, -0.026,  0.107],
            [ 0.090,  0.078,  0.068, -0.047,  0.153,  ..., -0.052, -0.063,  0.069, -0.030,  0.119]],
    
            [[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.132,  0.046,  0.054, -0.046,  0.109,  ..., -0.043, -0.005,  0.098, -0.038,  0.072],
            ...,
            [ 0.064,  0.077, -0.109, -0.184,  0.064,  ...,  0.100, -0.053,  0.139,  0.047,  0.058],
            [ 0.062,  0.056, -0.136, -0.172,  0.081,  ...,  0.073, -0.063,  0.159,  0.054,  0.100],
            [ 0.026,  0.006, -0.107, -0.160,  0.105,  ...,  0.041, -0.043,  0.172,  0.049,  0.113],
            [ 0.087, -0.002,  0.011, -0.106,  0.083,  ..., -0.021, -0.031,  0.087,  0.058,  0.047],
            [ 0.099,  0.001,  0.045, -0.068,  0.126,  ..., -0.043, -0.028,  0.099,  0.008,  0.101]],
    
            [[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.132,  0.046,  0.054, -0.046,  0.109,  ..., -0.043, -0.005,  0.098, -0.038,  0.072],
            ...,
            [ 0.046,  0.065, -0.045, -0.082,  0.068,  ...,  0.069, -0.037,  0.106,  0.019,  0.102],
            [ 0.040,  0.045, -0.063, -0.106,  0.046,  ...,  0.068, -0.051,  0.122,  0.023,  0.072],
            [ 0.101,  0.050, -0.019, -0.106,  0.068,  ...,  0.007, -0.035,  0.119,  0.045,  0.034],
            [ 0.126,  0.035,  0.041, -0.078,  0.091,  ..., -0.039, -0.030,  0.111,  0.004,  0.087],
            [ 0.125,  0.028,  0.066, -0.058,  0.122,  ..., -0.053, -0.031,  0.084, -0.016,  0.088]],
    
            [[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.132,  0.046,  0.054, -0.046,  0.109,  ..., -0.043, -0.005,  0.098, -0.038,  0.072],
            ...,
            [ 0.090,  0.034, -0.028, -0.124,  0.030,  ...,  0.071, -0.059,  0.109,  0.089,  0.058],
            [ 0.109,  0.058, -0.057, -0.142,  0.042,  ...,  0.064, -0.043,  0.122,  0.082,  0.047],
            [ 0.091,  0.036, -0.018, -0.094,  0.039,  ...,  0.036, -0.003,  0.126,  0.028,  0.066],
            [ 0.122,  0.046,  0.041, -0.068,  0.088,  ..., -0.023, -0.023,  0.114,  0.007,  0.052],
            [ 0.127,  0.034,  0.056, -0.061,  0.126,  ..., -0.047, -0.022,  0.102, -0.010,  0.079]],
    
            ...,
    
            [[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.132,  0.046,  0.054, -0.046,  0.109,  ..., -0.043, -0.005,  0.098, -0.038,  0.072],
            ...,
            [-0.042, -0.062,  0.058, -0.124,  0.102,  ...,  0.064, -0.063,  0.070,  0.269,  0.059],
            [ 0.005, -0.011, -0.033, -0.105,  0.090,  ...,  0.051, -0.037,  0.088,  0.175,  0.086],
            [ 0.050, -0.022, -0.007, -0.003,  0.033,  ...,  0.051, -0.002,  0.010,  0.080,  0.041],
            [ 0.107,  0.045,  0.001, -0.022,  0.127,  ..., -0.013, -0.072,  0.049,  0.027,  0.059],
            [ 0.121,  0.076,  0.010, -0.044,  0.148,  ..., -0.042, -0.076,  0.044,  0.004,  0.082]],
    
            [[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.132,  0.046,  0.054, -0.046,  0.109,  ..., -0.043, -0.005,  0.098, -0.038,  0.072],
            ...,
            [ 0.106,  0.021,  0.173, -0.125,  0.190,  ..., -0.045, -0.106,  0.089,  0.073,  0.136],
            [ 0.077,  0.070,  0.030, -0.137,  0.173,  ...,  0.030, -0.067,  0.129,  0.023,  0.162],
            [ 0.067,  0.056,  0.031, -0.074,  0.039,  ...,  0.057, -0.016,  0.059,  0.012,  0.073],
            [ 0.095,  0.077,  0.036, -0.066,  0.120,  ...,  0.021, -0.044,  0.117, -0.004,  0.073],
            [ 0.125,  0.057,  0.063, -0.060,  0.109,  ..., -0.030, -0.051,  0.102, -0.016,  0.085]],
    
            [[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.132,  0.046,  0.054, -0.046,  0.109,  ..., -0.043, -0.005,  0.098, -0.038,  0.072],
            ...,
            [ 0.048,  0.031, -0.103, -0.133,  0.037,  ...,  0.069, -0.058,  0.122,  0.030,  0.073],
            [ 0.043, -0.006, -0.082, -0.113,  0.040,  ...,  0.040, -0.053,  0.122,  0.032,  0.083],
            [ 0.088,  0.013, -0.004, -0.104,  0.083,  ..., -0.036, -0.046,  0.100,  0.048,  0.055],
            [ 0.115,  0.017,  0.061, -0.080,  0.102,  ..., -0.054, -0.037,  0.096, -0.002,  0.102],
            [ 0.113,  0.023,  0.075, -0.052,  0.118,  ..., -0.050, -0.031,  0.078, -0.022,  0.093]],
    
            [[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.132,  0.046,  0.054, -0.046,  0.109,  ..., -0.043, -0.005,  0.098, -0.038,  0.072],
            ...,
            [ 0.091, -0.014,  0.236, -0.042,  0.136,  ..., -0.139, -0.112, -0.028,  0.070,  0.103],
            [ 0.095, -0.008,  0.239, -0.029,  0.136,  ..., -0.127, -0.120, -0.014,  0.055,  0.126],
            [ 0.123,  0.026,  0.213, -0.024,  0.147,  ..., -0.134, -0.105,  0.010,  0.061,  0.121],
            [ 0.127,  0.038,  0.211, -0.009,  0.147,  ..., -0.092, -0.124,  0.040,  0.005,  0.114],
            [ 0.118,  0.049,  0.117, -0.012,  0.138,  ..., -0.074, -0.069,  0.046, -0.004,  0.098]], // << eg see this line, they are same
    
            [[ 0.057,  0.105, -0.045, -0.055,  0.214,  ..., -0.022, -0.023,  0.018,  0.058, -0.041],
            [ 0.069,  0.013,  0.026, -0.040,  0.123,  ..., -0.054,  0.014,  0.114,  0.029,  0.067],
            [ 0.142,  0.028,  0.018, -0.053,  0.123,  ..., -0.046, -0.019,  0.104,  0.024,  0.047],
            [ 0.159,  0.032,  0.028, -0.059,  0.115,  ..., -0.034, -0.002,  0.121, -0.036,  0.056],
            [ 0.132,  0.046,  0.054, -0.046,  0.109,  ..., -0.043, -0.005,  0.098, -0.038,  0.072],
            ...,
            [ 0.152, -0.010,  0.118, -0.057,  0.071,  ..., -0.202,  0.018,  0.039,  0.035,  0.105],
            [ 0.118, -0.042,  0.180, -0.128,  0.109,  ..., -0.099, -0.110,  0.044,  0.063,  0.076],
            [ 0.073, -0.049,  0.173, -0.136,  0.087,  ..., -0.036, -0.123,  0.002,  0.133,  0.087],
            [ 0.069, -0.013,  0.158, -0.002,  0.055,  ..., -0.063, -0.058,  0.019,  0.025,  0.139],
            [ 0.124,  0.043,  0.126,  0.003,  0.102,  ..., -0.062, -0.050,  0.027,  0.027,  0.043]]], device='cuda:0', grad_fn=<CudnnRnnBackward0>)
    tensor([[[ 0.183, -0.012,  0.073, -0.071, -0.109,  ...,  0.137, -0.035, -0.130, -0.083,  0.095],
            [ 0.180,  0.006, -0.001, -0.072, -0.091,  ...,  0.107, -0.048, -0.102, -0.115,  0.042],
            [ 0.172, -0.030, -0.035, -0.092, -0.117,  ...,  0.134, -0.004, -0.097, -0.090,  0.047],
            [ 0.174, -0.006, -0.010, -0.082, -0.100,  ...,  0.104, -0.027, -0.120, -0.122,  0.045],
            [ 0.185,  0.012, -0.005, -0.058, -0.087,  ...,  0.119, -0.025, -0.112, -0.108,  0.044],
            ...,
            [ 0.133,  0.020, -0.024, -0.084, -0.078,  ...,  0.120, -0.073, -0.125, -0.093,  0.026],
            [ 0.162, -0.003, -0.011, -0.075, -0.095,  ...,  0.102, -0.036, -0.110, -0.100,  0.031],
            [ 0.172, -0.020, -0.010, -0.087, -0.097,  ...,  0.100, -0.028, -0.122, -0.131,  0.044],
            [ 0.165,  0.015,  0.037, -0.078, -0.044,  ...,  0.125, -0.020, -0.113, -0.134,  0.039],
            [ 0.124,  0.012, -0.002, -0.105,  0.006,  ...,  0.146, -0.070, -0.155, -0.155,  0.038]],
    
            [[ 0.147,  0.050,  0.239, -0.025,  0.126,  ..., -0.100, -0.135,  0.031, -0.027,  0.131],
            [ 0.090,  0.078,  0.068, -0.047,  0.153,  ..., -0.052, -0.063,  0.069, -0.030,  0.119],
            [ 0.099,  0.001,  0.045, -0.068,  0.126,  ..., -0.043, -0.028,  0.099,  0.008,  0.101],
            [ 0.125,  0.028,  0.066, -0.058,  0.122,  ..., -0.053, -0.031,  0.084, -0.016,  0.088],
            [ 0.127,  0.034,  0.056, -0.061,  0.126,  ..., -0.047, -0.022,  0.102, -0.010,  0.079],
            ...,
            [ 0.121,  0.076,  0.010, -0.044,  0.148,  ..., -0.042, -0.076,  0.044,  0.004,  0.082],
            [ 0.125,  0.057,  0.063, -0.060,  0.109,  ..., -0.030, -0.051,  0.102, -0.016,  0.085],
            [ 0.113,  0.023,  0.075, -0.052,  0.118,  ..., -0.050, -0.031,  0.078, -0.022,  0.093],
            [ 0.118,  0.049,  0.117, -0.012,  0.138,  ..., -0.074, -0.069,  0.046, -0.004,  0.098], // << eg see this line, they are same
            [ 0.124,  0.043,  0.126,  0.003,  0.102,  ..., -0.062, -0.050,  0.027,  0.027,  0.043]]], device='cuda:0', grad_fn=<CudnnRnnBackward0>)
    
  • The following is a design of an RNN; for the complete code, see https://www.youtube.com/watch?v=0_PgWWmauHk
    -> https://github.com/patrickloeber/pytorch-examples/blob/master/rnn-lstm-gru/main.py

    It uses an RNN to classify MNIST digits, treating each 28x28 image as a sequence of its 28 rows (row-wise flattening of the 2D MNIST matrix).

    With

    • input_size = 28 # H_in - input_size – The number of expected features in the input x # size of each vector (row or col wise)
    • sequence_length = 28 # L - sequence length or the number of time steps

    It uses 2 stacked layers in the RNN.

    The output size is transformed after the RNN by the final fully connected layer:

            self.fc = nn.Linear(hidden_size, num_classes)
            # @shape: (batch_size, 128) -> (batch_size, 10)
            (//...) 
            out = self.fc(out)
    
    import torch
    import torch.nn as nn
    from torch import Tensor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # defined as in the full script linked above

    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, num_classes):
            super(RNN, self).__init__()
            self.num_layers = num_layers
            self.hidden_size = hidden_size
    
            # RNN — PyTorch 2.5 documentation
            # https://pytorch.org/docs/stable/generated/torch.nn.RNN.html
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True, nonlinearity="tanh", bidirectional=False)
            # @shape: 28, 128, 2
    
            # or:
            # self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
            # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
    
            self.fc = nn.Linear(hidden_size, num_classes)
            # @shape: (batch_size, 128) -> (batch_size, 10)
    
        def forward(self, x: Tensor):
            # Set initial hidden states (and cell states for LSTM)
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
            # @nota: n = batch_size = 100
            # @shape: x: (batch_size, sequence_length, input_size) = (n, 28, 28)
            # @shape: h0: (2, n, 128)
    
            # or with:
            # c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
    
            # Forward propagate the RNN (the recurrence over the time steps happens inside this single call)
            out, h_t = self.rnn(x, h0)
            out: Tensor
            # @shape: out: (batch_size, seq_length, hidden_size) = (n, 28, 128)
            # (h_t): Shape: (num_layers, batch_size, hidden_size)
            # print(">--")
            # print(out)
            # print(h_t)
            # for i in range(0, h_t.shape[1]):
            #     # for every batch, the tensor in the last time step of {y_t} output == the tensor in the last layer of h_t hidden_state
            #     if not torch.equal(out[i, -1], h_t[-1, i]):
            #         raise ValueError("")
    
            # or:
            # out, _ = self.lstm(x, (h0,c0))
    
            # Decode the hidden state of the last time step # The `-1` index refers to the last element along the second dimension (sequence length in this case).
            out = out[:, -1, :]
            # @shape: out: (n, 128) # output of the 28th/last time step
    
            out = self.fc(out)
            # @shape: out: (n, 10)
            return out
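
    A minimal usage sketch of the class above (not part of the original post; the hyperparameters follow the MNIST setup described earlier, and `device` is the one defined with the class):

    model = RNN(input_size=28, hidden_size=128, num_layers=2, num_classes=10).to(device)
    x = torch.randn(100, 28, 28).to(device)  # (batch_size, sequence_length, input_size)
    logits = model(x)
    print(logits.shape)                       # torch.Size([100, 10])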
    

Upvotes: 0

Karthik Ragunath A

Reputation: 71

In PyTorch, output gives the hidden state of every time step from the last layer of the LSTM stack, while the hidden state and cell state (h_n, c_n) give the final hidden state and cell state of every layer in the stack.

import torch
import torch.nn as nn

torch.manual_seed(1)
inputs = [torch.randn(1, 3) for _ in range(5)]  # a sequence of 5 time steps; each step has shape (1, 3): batch size 1, 3 input features
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))  # initial h and c, each of shape (num_layers * num_directions, batch, hidden_size) = (1, 1, 3)
                                 # the batch size (1) must match the input's batch size, and the last dim must match the LSTM's hidden_size

lstm = nn.LSTM(3, 3)  # input_size = 3, hidden_size = 3
for i in inputs:
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print('out:', out)
    print('hidden:', hidden)

Output

out: tensor([[[-0.1124, -0.0653,  0.2808]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.1124, -0.0653,  0.2808]]], grad_fn=<StackBackward>), tensor([[[-0.2883, -0.2846,  2.0720]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.1675, -0.0376,  0.4402]]], grad_fn=<StackBackward>)
hidden: (tensor([[[ 0.1675, -0.0376,  0.4402]]], grad_fn=<StackBackward>), tensor([[[ 0.4394, -0.1226,  1.5611]]], grad_fn=<StackBackward>))
out: tensor([[[0.3699, 0.0150, 0.1429]]], grad_fn=<StackBackward>)
hidden: (tensor([[[0.3699, 0.0150, 0.1429]]], grad_fn=<StackBackward>), tensor([[[0.8432, 0.0618, 0.9413]]], grad_fn=<StackBackward>))
out: tensor([[[0.1795, 0.0296, 0.2957]]], grad_fn=<StackBackward>)
hidden: (tensor([[[0.1795, 0.0296, 0.2957]]], grad_fn=<StackBackward>), tensor([[[0.4541, 0.1121, 0.9320]]], grad_fn=<StackBackward>))
out: tensor([[[0.1365, 0.0596, 0.3931]]], grad_fn=<StackBackward>)
hidden: (tensor([[[0.1365, 0.0596, 0.3931]]], grad_fn=<StackBackward>), tensor([[[0.3430, 0.1948, 1.0255]]], grad_fn=<StackBackward>))

Multi-Layered LSTM

import torch
import torch.nn as nn
torch.manual_seed(1)
num_layers = 2
inputs = [torch.randn(1, 3) for _ in range(5)] 
hidden = (torch.randn(2, 1, 3),
          torch.randn(2, 1, 3))
lstm = nn.LSTM(input_size=3, hidden_size=3, num_layers=2)
for i in inputs:
    # Step through the sequence one element at a time.
    # after each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print('out:', out)
    print('hidden:', hidden)

Output

out: tensor([[[-0.0819,  0.1214, -0.2586]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.2625,  0.4415, -0.4917]],

        [[-0.0819,  0.1214, -0.2586]]], grad_fn=<StackBackward>), tensor([[[-2.5740,  0.7832, -0.9211]],

        [[-0.2803,  0.5175, -0.5330]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1298,  0.2797, -0.0882]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.3818,  0.3306, -0.3020]],

        [[-0.1298,  0.2797, -0.0882]]], grad_fn=<StackBackward>), tensor([[[-2.3980,  0.6347, -0.6592]],

        [[-0.3643,  0.9301, -0.1326]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1630,  0.3187,  0.0728]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.5612,  0.3134, -0.0782]],

        [[-0.1630,  0.3187,  0.0728]]], grad_fn=<StackBackward>), tensor([[[-1.7555,  0.6882, -0.3575]],

        [[-0.4571,  1.2094,  0.1061]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1723,  0.3274,  0.1546]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.5112,  0.1597, -0.0901]],

        [[-0.1723,  0.3274,  0.1546]]], grad_fn=<StackBackward>), tensor([[[-1.4417,  0.5892, -0.2489]],

        [[-0.4940,  1.3620,  0.2255]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1847,  0.2968,  0.1333]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.3256,  0.3217, -0.1899]],

        [[-0.1847,  0.2968,  0.1333]]], grad_fn=<StackBackward>), tensor([[[-1.7925,  0.6096, -0.4432]],

        [[-0.5147,  1.4031,  0.2014]]], grad_fn=<StackBackward>))
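
As a quick sanity check (my own addition, not in the original output above): after the loop, out holds the single step just processed, and it equals the last-layer entry of the returned hidden state:

h_n, c_n = hidden
print(torch.allclose(out[-1], h_n[-1]))  # True: out's (only) time step == last layer's hidden state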

Bi-Directional Multi-Layered LSTM

import torch
import torch.nn as nn
torch.manual_seed(1)
num_layers = 2
is_bidirectional = True
inputs = [torch.randn(1, 3) for _ in range(5)] 
hidden = (torch.randn(4, 1, 3),
          torch.randn(4, 1, 3)) #4 -> (2 * 2) -> num_layers * num_directions
lstm = nn.LSTM(input_size=3, hidden_size=3, num_layers=2, bidirectional=is_bidirectional)

for i in inputs:
    # Step through the sequence one element at a time.
    # after each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print('out:', out)
    print('hidden:', hidden)
    # output dim -> (seq_len, batch, num_directions * hidden_size) -> (5, 1, 2*3)
    # hidden dim -> (num_layers * num_directions, batch, hidden_size) -> (2 * 2, 1, 3)
    # cell state dim -> (num_layers * num_directions, batch, hidden_size) -> (2 * 2, 1, 3)

Output

out: tensor([[[-0.4620,  0.1115, -0.1087,  0.1646,  0.0173, -0.2196]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.5187,  0.2656, -0.2543]],

        [[ 0.4175,  0.0539,  0.0633]],

        [[-0.4620,  0.1115, -0.1087]],

        [[ 0.1646,  0.0173, -0.2196]]], grad_fn=<StackBackward>), tensor([[[ 1.1546,  0.4012, -0.4119]],

        [[ 0.7999,  0.2632,  0.2587]],

        [[-1.4196,  0.2075, -0.3148]],

        [[ 0.6605,  0.0243, -0.5783]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1860,  0.1359, -0.2719,  0.0815,  0.0061, -0.0980]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.2945,  0.0842, -0.1580]],

        [[ 0.2766, -0.1873,  0.2416]],

        [[-0.1860,  0.1359, -0.2719]],

        [[ 0.0815,  0.0061, -0.0980]]], grad_fn=<StackBackward>), tensor([[[ 0.5453,  0.1281, -0.2497]],

        [[ 0.9706, -0.3592,  0.4834]],

        [[-0.3706,  0.2681, -0.6189]],

        [[ 0.2029,  0.0121, -0.3028]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.1095,  0.1520, -0.3238,  0.0283,  0.0387, -0.0820]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.1427,  0.0859, -0.2926]],

        [[ 0.1536, -0.2343,  0.0727]],

        [[ 0.1095,  0.1520, -0.3238]],

        [[ 0.0283,  0.0387, -0.0820]]], grad_fn=<StackBackward>), tensor([[[ 0.2386,  0.1646, -0.4102]],

        [[ 0.2636, -0.4828,  0.1889]],

        [[ 0.1967,  0.2848, -0.7155]],

        [[ 0.0735,  0.0702, -0.2859]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.2346,  0.1576, -0.4006, -0.0053,  0.0256, -0.0653]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.1706,  0.0147, -0.0341]],

        [[ 0.1835, -0.3951,  0.2506]],

        [[ 0.2346,  0.1576, -0.4006]],

        [[-0.0053,  0.0256, -0.0653]]], grad_fn=<StackBackward>), tensor([[[ 0.3422,  0.0269, -0.0475]],

        [[ 0.4235, -0.9144,  0.5655]],

        [[ 0.4589,  0.2807, -0.8332]],

        [[-0.0133,  0.0507, -0.1996]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.2774,  0.1639, -0.4460, -0.0228,  0.0086, -0.0369]]],
       grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.2147, -0.0191,  0.0677]],

        [[ 0.2516, -0.4591,  0.3327]],

        [[ 0.2774,  0.1639, -0.4460]],

        [[-0.0228,  0.0086, -0.0369]]], grad_fn=<StackBackward>), tensor([[[ 0.4414, -0.0299,  0.0889]],

        [[ 0.6360, -1.2360,  0.7229]],

        [[ 0.5692,  0.2843, -0.9375]],

        [[-0.0569,  0.0177, -0.1039]]], grad_fn=<StackBackward>))
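
A note on how out and hidden relate in the bidirectional case (my own observation): out is the concatenation of the forward and backward hidden states of the last layer. In general, for a full sequence processed in one call, out[-1, :, :hidden_size] equals h_n[-2] (forward direction, last time step) and out[0, :, hidden_size:] equals h_n[-1] (backward direction, first time step). With the one-step-at-a-time loop above, both reduce to:

h_n, c_n = hidden
hidden_size = 3
print(torch.allclose(out[-1, :, :hidden_size], h_n[-2]))  # True: forward half of out == forward hidden state of last layer
print(torch.allclose(out[-1, :, hidden_size:], h_n[-1]))  # True: backward half of out == backward hidden state of last layer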

Upvotes: 6

Pulkit Bansal

Reputation: 2073

I just verified some of this using code, and it's indeed correct that for a depth-1 LSTM, h_n is the same as the last value of the "output" (this will not be true for an LSTM of depth > 1, though, as explained above by @nnnmmm).

So, basically, the "output" we get after applying the LSTM is not the same as o_t as defined in the documentation; rather, it is h_t.

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.LSTM( input_size = 1, hidden_size = 50, num_layers  = 1 )
x = torch.rand( 50, 1, 1)
output, (hn, cn) = model(x)

Now one can check that output[-1] and hn both hold the same values, shown below:

tensor([[ 0.1140, -0.0600, -0.0540,  0.1492, -0.0339, -0.0150, -0.0486,  0.0188,
          0.0504,  0.0595, -0.0176, -0.0035,  0.0384, -0.0274,  0.1076,  0.0843,
         -0.0443,  0.0218, -0.0093,  0.0002,  0.1335,  0.0926,  0.0101, -0.1300,
         -0.1141,  0.0072, -0.0142,  0.0018,  0.0071,  0.0247,  0.0262,  0.0109,
          0.0374,  0.0366,  0.0017,  0.0466,  0.0063,  0.0295,  0.0536,  0.0339,
          0.0528, -0.0305,  0.0243, -0.0324,  0.0045, -0.1108, -0.0041, -0.1043,
         -0.0141, -0.1222]], grad_fn=<SelectBackward>)
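
The comparison itself can be written, for instance, as:

print(torch.allclose(output[-1], hn[-1]))  # True: last time step of output == hidden state of the (single) layer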

Upvotes: 6

prosti

Reputation: 46291

It really depends on the model you use and how you interpret it. The output may be:

  • a single LSTM cell hidden state
  • several LSTM cell hidden states
  • all the hidden states outputs

The output is almost never interpreted directly. If the input is encoded, there should be a softmax layer to decode the results.

Note: in language modeling, hidden states are used to define the probability of the next word: p(w_{t+1} | w_1, ..., w_t) = softmax(W·h_t + b).
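
A minimal sketch of that decoding step (the sizes and names here are just illustrative):

import torch
import torch.nn as nn

hidden_size, vocab_size = 128, 10000
decoder = nn.Linear(hidden_size, vocab_size)   # W and b from the formula above
h_t = torch.randn(1, hidden_size)              # hidden state at time t, batch of 1
probs = torch.softmax(decoder(h_t), dim=-1)    # p(w_{t+1} | w_1, ..., w_t)
print(probs.shape)                             # torch.Size([1, 10000])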

Upvotes: 4

Jibin Mathew

Reputation: 5102

The output is the tensor of all the hidden states from each time step in the RNN (LSTM), and the hidden state returned by the RNN (LSTM) is the last hidden state from the last time step of the input sequence. You can check this by collecting the hidden states from each step and comparing them to the output (provided you are not using pack_padded_sequence), as sketched below.
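
Here is a sketch of that check (all names and sizes are my own): run the LSTM once over the whole sequence, then again one step at a time while collecting the hidden state at every step, and compare.

import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=1)
x = torch.randn(6, 1, 4)                  # (seq_len, batch, input_size)

# one call over the whole sequence
output, (h_n, c_n) = lstm(x)

# same sequence, one time step at a time, collecting h after every step
h = torch.zeros(1, 1, 8)
c = torch.zeros(1, 1, 8)
collected = []
for t in range(x.size(0)):
    step_out, (h, c) = lstm(x[t:t+1], (h, c))
    collected.append(h[-1])               # last layer's hidden state at step t
collected = torch.stack(collected)        # (seq_len, batch, hidden_size)

print(torch.allclose(output, collected))    # True
print(torch.allclose(output[-1], h_n[-1]))  # True: last step of output == returned hidden state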

Upvotes: 3

nnnmmm

Reputation: 8744

I made a diagram. The names follow the PyTorch docs, although I renamed num_layers to w.

output comprises all the hidden states in the last layer ("last" depth-wise, not time-wise). (h_n, c_n) comprises the hidden states after the last timestep, t = n, so you could potentially feed them into another LSTM.

LSTM diagram

The batch dimension is not included.
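
A small numerical illustration of the diagram (my own sketch; the sizes are arbitrary):

import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=5, hidden_size=7, num_layers=3)  # w = 3 layers in the diagram
x = torch.randn(4, 1, 5)                                   # (seq_len = n, batch, input_size)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([4, 1, 7]) -> every time step, last layer only
print(h_n.shape)     # torch.Size([3, 1, 7]) -> last time step, every layer
print(torch.allclose(output[-1], h_n[-1]))  # True: the single place where they overlap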

Upvotes: 252
