Reputation: 8439
I'm having trouble understanding the documentation for PyTorch's LSTM module (and also RNN and GRU, which are similar). Regarding the outputs, it says:
Outputs: output, (h_n, c_n)
- output (seq_len, batch, hidden_size * num_directions): tensor containing the output features (h_t) from the last layer of the RNN, for each t. If a torch.nn.utils.rnn.PackedSequence has been given as the input, the output will also be a packed sequence.
- h_n (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t=seq_len
- c_n (num_layers * num_directions, batch, hidden_size): tensor containing the cell state for t=seq_len
It seems that the variables output and h_n both give the values of the hidden state. Does h_n just redundantly provide the last time step that's already included in output, or is there something more to it than that?
Upvotes: 117
Views: 56613
Reputation: 1349
This post is about understanding the relation between hidden_size and output_size, the difference between h_t and out, and visualizing the actual values in an RNN.
(Disclaimer: this is just my personal understanding; I could be wrong.)
In the math (sketched below), the output size of y could in principle be chosen freely, based on the shape of the weight matrix for y. PyTorch, however, decided to treat the output y as the hidden state h itself (the output is used as the hidden state for the next time step), which is fine; some people use this design for RNNs. So, in PyTorch, output_size == hidden_size.
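For reference, a rough sketch of the textbook (Elman) RNN equations as I understand them; the symbols are my own shorthand, not PyTorch parameter names:

h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})
y_t = W_{y} h_t + b_{y}

The second line is the optional output projection whose W_y shape would set output_size; nn.RNN does not apply it and returns h_t directly, which is why its output size equals hidden_size.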
Note: [h_t vs out]
Though I said "PyTorch decided to treat the output y == hidden state h", the two returned values differ slightly in what they contain:
h_t: the final hidden state after processing all time steps (one entry per layer).
out: the output (the last layer's hidden state) at each time step.
You can try it and see that the overlapping values are indeed the same (as many other answers here have done):
out, h_t = self.rnn(x, h0)
for i in range(h_t.shape[1]):
    # for every batch element, the output at the last time step == the last layer of h_t (the hidden state)
    if not torch.equal(out[i, -1], h_t[-1, i]):
        raise ValueError("out[:, -1] and h_t[-1] should match")
Here is one real output from out, h_t = self.rnn(x, h0) (printing out and then h_t).
Watch this row: [ 0.118, 0.049, 0.117, -0.012, 0.138, ..., -0.074, -0.069, 0.046, -0.004, 0.098] -- the lines marked "<< eg see this line, they are same" below.
Each row of the last layer of h_t appears as the last row (last time step) of the corresponding batch element in out; there are no corresponding rows in out for the layer-0 entries of h_t.
h_t shape: (num_layers, batch_size, hidden_size) -- if batch_size were 1, each layer would contribute just one row vector.
out shape: (batch_size, sequence_length, hidden_size)
>--
tensor([[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.093, 0.082, 0.231, -0.073, 0.143, ..., -0.067, -0.138, 0.016, 0.007, 0.150],
[ 0.098, 0.067, 0.218, -0.055, 0.120, ..., -0.082, -0.096, 0.034, -0.005, 0.149],
[ 0.158, 0.026, 0.242, -0.053, 0.095, ..., -0.100, -0.105, 0.032, -0.022, 0.121],
[ 0.184, 0.051, 0.244, -0.048, 0.115, ..., -0.125, -0.123, 0.023, 0.007, 0.129],
[ 0.147, 0.050, 0.239, -0.025, 0.126, ..., -0.100, -0.135, 0.031, -0.027, 0.131]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.122, 0.021, 0.088, -0.125, 0.089, ..., -0.045, -0.003, 0.064, 0.024, 0.040],
...,
[ 0.071, -0.062, 0.179, -0.167, -0.020, ..., -0.195, -0.090, 0.108, 0.166, 0.113],
[ 0.137, -0.026, 0.259, -0.002, 0.045, ..., -0.156, -0.108, 0.046, 0.049, 0.138],
[ 0.122, -0.019, 0.159, 0.013, 0.085, ..., -0.098, -0.046, 0.006, -0.002, 0.072],
[ 0.133, 0.047, 0.095, 0.011, 0.123, ..., -0.081, -0.055, 0.033, -0.026, 0.107],
[ 0.090, 0.078, 0.068, -0.047, 0.153, ..., -0.052, -0.063, 0.069, -0.030, 0.119]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.064, 0.077, -0.109, -0.184, 0.064, ..., 0.100, -0.053, 0.139, 0.047, 0.058],
[ 0.062, 0.056, -0.136, -0.172, 0.081, ..., 0.073, -0.063, 0.159, 0.054, 0.100],
[ 0.026, 0.006, -0.107, -0.160, 0.105, ..., 0.041, -0.043, 0.172, 0.049, 0.113],
[ 0.087, -0.002, 0.011, -0.106, 0.083, ..., -0.021, -0.031, 0.087, 0.058, 0.047],
[ 0.099, 0.001, 0.045, -0.068, 0.126, ..., -0.043, -0.028, 0.099, 0.008, 0.101]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.046, 0.065, -0.045, -0.082, 0.068, ..., 0.069, -0.037, 0.106, 0.019, 0.102],
[ 0.040, 0.045, -0.063, -0.106, 0.046, ..., 0.068, -0.051, 0.122, 0.023, 0.072],
[ 0.101, 0.050, -0.019, -0.106, 0.068, ..., 0.007, -0.035, 0.119, 0.045, 0.034],
[ 0.126, 0.035, 0.041, -0.078, 0.091, ..., -0.039, -0.030, 0.111, 0.004, 0.087],
[ 0.125, 0.028, 0.066, -0.058, 0.122, ..., -0.053, -0.031, 0.084, -0.016, 0.088]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.090, 0.034, -0.028, -0.124, 0.030, ..., 0.071, -0.059, 0.109, 0.089, 0.058],
[ 0.109, 0.058, -0.057, -0.142, 0.042, ..., 0.064, -0.043, 0.122, 0.082, 0.047],
[ 0.091, 0.036, -0.018, -0.094, 0.039, ..., 0.036, -0.003, 0.126, 0.028, 0.066],
[ 0.122, 0.046, 0.041, -0.068, 0.088, ..., -0.023, -0.023, 0.114, 0.007, 0.052],
[ 0.127, 0.034, 0.056, -0.061, 0.126, ..., -0.047, -0.022, 0.102, -0.010, 0.079]],
...,
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[-0.042, -0.062, 0.058, -0.124, 0.102, ..., 0.064, -0.063, 0.070, 0.269, 0.059],
[ 0.005, -0.011, -0.033, -0.105, 0.090, ..., 0.051, -0.037, 0.088, 0.175, 0.086],
[ 0.050, -0.022, -0.007, -0.003, 0.033, ..., 0.051, -0.002, 0.010, 0.080, 0.041],
[ 0.107, 0.045, 0.001, -0.022, 0.127, ..., -0.013, -0.072, 0.049, 0.027, 0.059],
[ 0.121, 0.076, 0.010, -0.044, 0.148, ..., -0.042, -0.076, 0.044, 0.004, 0.082]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.106, 0.021, 0.173, -0.125, 0.190, ..., -0.045, -0.106, 0.089, 0.073, 0.136],
[ 0.077, 0.070, 0.030, -0.137, 0.173, ..., 0.030, -0.067, 0.129, 0.023, 0.162],
[ 0.067, 0.056, 0.031, -0.074, 0.039, ..., 0.057, -0.016, 0.059, 0.012, 0.073],
[ 0.095, 0.077, 0.036, -0.066, 0.120, ..., 0.021, -0.044, 0.117, -0.004, 0.073],
[ 0.125, 0.057, 0.063, -0.060, 0.109, ..., -0.030, -0.051, 0.102, -0.016, 0.085]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.048, 0.031, -0.103, -0.133, 0.037, ..., 0.069, -0.058, 0.122, 0.030, 0.073],
[ 0.043, -0.006, -0.082, -0.113, 0.040, ..., 0.040, -0.053, 0.122, 0.032, 0.083],
[ 0.088, 0.013, -0.004, -0.104, 0.083, ..., -0.036, -0.046, 0.100, 0.048, 0.055],
[ 0.115, 0.017, 0.061, -0.080, 0.102, ..., -0.054, -0.037, 0.096, -0.002, 0.102],
[ 0.113, 0.023, 0.075, -0.052, 0.118, ..., -0.050, -0.031, 0.078, -0.022, 0.093]],
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.091, -0.014, 0.236, -0.042, 0.136, ..., -0.139, -0.112, -0.028, 0.070, 0.103],
[ 0.095, -0.008, 0.239, -0.029, 0.136, ..., -0.127, -0.120, -0.014, 0.055, 0.126],
[ 0.123, 0.026, 0.213, -0.024, 0.147, ..., -0.134, -0.105, 0.010, 0.061, 0.121],
[ 0.127, 0.038, 0.211, -0.009, 0.147, ..., -0.092, -0.124, 0.040, 0.005, 0.114],
[ 0.118, 0.049, 0.117, -0.012, 0.138, ..., -0.074, -0.069, 0.046, -0.004, 0.098]], // << eg see this line, they are same
[[ 0.057, 0.105, -0.045, -0.055, 0.214, ..., -0.022, -0.023, 0.018, 0.058, -0.041],
[ 0.069, 0.013, 0.026, -0.040, 0.123, ..., -0.054, 0.014, 0.114, 0.029, 0.067],
[ 0.142, 0.028, 0.018, -0.053, 0.123, ..., -0.046, -0.019, 0.104, 0.024, 0.047],
[ 0.159, 0.032, 0.028, -0.059, 0.115, ..., -0.034, -0.002, 0.121, -0.036, 0.056],
[ 0.132, 0.046, 0.054, -0.046, 0.109, ..., -0.043, -0.005, 0.098, -0.038, 0.072],
...,
[ 0.152, -0.010, 0.118, -0.057, 0.071, ..., -0.202, 0.018, 0.039, 0.035, 0.105],
[ 0.118, -0.042, 0.180, -0.128, 0.109, ..., -0.099, -0.110, 0.044, 0.063, 0.076],
[ 0.073, -0.049, 0.173, -0.136, 0.087, ..., -0.036, -0.123, 0.002, 0.133, 0.087],
[ 0.069, -0.013, 0.158, -0.002, 0.055, ..., -0.063, -0.058, 0.019, 0.025, 0.139],
[ 0.124, 0.043, 0.126, 0.003, 0.102, ..., -0.062, -0.050, 0.027, 0.027, 0.043]]], device='cuda:0', grad_fn=<CudnnRnnBackward0>)
tensor([[[ 0.183, -0.012, 0.073, -0.071, -0.109, ..., 0.137, -0.035, -0.130, -0.083, 0.095],
[ 0.180, 0.006, -0.001, -0.072, -0.091, ..., 0.107, -0.048, -0.102, -0.115, 0.042],
[ 0.172, -0.030, -0.035, -0.092, -0.117, ..., 0.134, -0.004, -0.097, -0.090, 0.047],
[ 0.174, -0.006, -0.010, -0.082, -0.100, ..., 0.104, -0.027, -0.120, -0.122, 0.045],
[ 0.185, 0.012, -0.005, -0.058, -0.087, ..., 0.119, -0.025, -0.112, -0.108, 0.044],
...,
[ 0.133, 0.020, -0.024, -0.084, -0.078, ..., 0.120, -0.073, -0.125, -0.093, 0.026],
[ 0.162, -0.003, -0.011, -0.075, -0.095, ..., 0.102, -0.036, -0.110, -0.100, 0.031],
[ 0.172, -0.020, -0.010, -0.087, -0.097, ..., 0.100, -0.028, -0.122, -0.131, 0.044],
[ 0.165, 0.015, 0.037, -0.078, -0.044, ..., 0.125, -0.020, -0.113, -0.134, 0.039],
[ 0.124, 0.012, -0.002, -0.105, 0.006, ..., 0.146, -0.070, -0.155, -0.155, 0.038]],
[[ 0.147, 0.050, 0.239, -0.025, 0.126, ..., -0.100, -0.135, 0.031, -0.027, 0.131],
[ 0.090, 0.078, 0.068, -0.047, 0.153, ..., -0.052, -0.063, 0.069, -0.030, 0.119],
[ 0.099, 0.001, 0.045, -0.068, 0.126, ..., -0.043, -0.028, 0.099, 0.008, 0.101],
[ 0.125, 0.028, 0.066, -0.058, 0.122, ..., -0.053, -0.031, 0.084, -0.016, 0.088],
[ 0.127, 0.034, 0.056, -0.061, 0.126, ..., -0.047, -0.022, 0.102, -0.010, 0.079],
...,
[ 0.121, 0.076, 0.010, -0.044, 0.148, ..., -0.042, -0.076, 0.044, 0.004, 0.082],
[ 0.125, 0.057, 0.063, -0.060, 0.109, ..., -0.030, -0.051, 0.102, -0.016, 0.085],
[ 0.113, 0.023, 0.075, -0.052, 0.118, ..., -0.050, -0.031, 0.078, -0.022, 0.093],
[ 0.118, 0.049, 0.117, -0.012, 0.138, ..., -0.074, -0.069, 0.046, -0.004, 0.098], // << eg see this line, they are same
[ 0.124, 0.043, 0.126, 0.003, 0.102, ..., -0.062, -0.050, 0.027, 0.027, 0.043]]], device='cuda:0', grad_fn=<CudnnRnnBackward0>)
Following is the design of the RNN; for the complete code, see https://www.youtube.com/watch?v=0_PgWWmauHk
-> https://github.com/patrickloeber/pytorch-examples/blob/master/rnn-lstm-gru/main.py
It uses an RNN to classify MNIST digits, treating each 28x28 image as a sequence of rows (row-wise flattening of the 2D matrix), with
input_size = 28       # H_in - the number of expected features in the input x (the size of each row vector)
sequence_length = 28  # L - the sequence length, i.e. the number of time steps
and 2 stacked RNN layers.
The output size is only changed after the RNN, at
self.fc = nn.Linear(hidden_size, num_classes)
# @shape: (batch_size, 128) -> (batch_size, 10)
# ...
out = self.fc(out)
import torch
import torch.nn as nn
from torch import Tensor

# `device` is assumed to be defined elsewhere in the full example, e.g.:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # RNN — PyTorch documentation
        # https://pytorch.org/docs/stable/generated/torch.nn.RNN.html
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True, nonlinearity="tanh", bidirectional=False)
        # here: input_size=28, hidden_size=128, num_layers=2
        # or:
        # self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        # self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        self.fc = nn.Linear(hidden_size, num_classes)
        # @shape: (batch_size, 128) -> (batch_size, 10)

    def forward(self, x: Tensor):
        # Set the initial hidden state (and cell state for LSTM)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # @note: n = batch_size = 100
        # @shape: x:  (batch_size, sequence_length, input_size) = (n, 28, 28)
        # @shape: h0: (num_layers, batch_size, hidden_size) = (2, n, 128)
        # or, for LSTM:
        # c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward propagate the RNN (the recurrence over all time steps happens inside this single call)
        out, h_t = self.rnn(x, h0)
        out: Tensor
        # @shape: out: (batch_size, seq_length, hidden_size) = (n, 28, 128)
        # @shape: h_t: (num_layers, batch_size, hidden_size) = (2, n, 128)

        # print(">--")
        # print(out)
        # print(h_t)
        # for i in range(h_t.shape[1]):
        #     # for every batch element, the output at the last time step == the last layer of h_t (the hidden state)
        #     if not torch.equal(out[i, -1], h_t[-1, i]):
        #         raise ValueError("out[:, -1] and h_t[-1] should match")

        # or:
        # out, _ = self.lstm(x, (h0, c0))

        # Decode the hidden state of the last time step; `-1` indexes the last element
        # along the second dimension (the sequence length here).
        out = out[:, -1, :]
        # @shape: out: (n, 128) # output of the 28th/last time step
        out = self.fc(out)
        # @shape: out: (n, 10)
        return out
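As a rough usage sketch of my own (the batch size and the device line are assumptions, not part of the linked example), showing the shapes end to end:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed
model = RNN(input_size=28, hidden_size=128, num_layers=2, num_classes=10).to(device)

images = torch.randn(100, 28, 28).to(device)  # (batch_size, sequence_length, input_size)
logits = model(images)
print(logits.shape)                           # torch.Size([100, 10])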
Upvotes: 0
Reputation: 71
In PyTorch, the output parameter gives the output of each individual LSTM cell (i.e. every time step) in the last layer of the LSTM stack, while the hidden state and cell state give the final hidden state and cell state of every layer in the stack.
import torch
import torch.nn as nn

torch.manual_seed(1)
inputs = [torch.randn(1, 3) for _ in range(5)]  # 5 inputs, fed one at a time below; each has shape (1, 3), i.e. 3 input features
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))  # initial h and c of shape (1, 1, 3) = (num_layers * num_directions, batch, hidden_size)
# Since the input has batch size 1, h and c must also be initialized with batch size 1,
# and the feature sizes must match the LSTM's input and hidden sizes.
lstm = nn.LSTM(3, 3)  # input_size=3, hidden_size=3 (so the output features are also 3-dimensional)
for i in inputs:
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print('out:', out)
    print('hidden:', hidden)
Output
out: tensor([[[-0.1124, -0.0653, 0.2808]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.1124, -0.0653, 0.2808]]], grad_fn=<StackBackward>), tensor([[[-0.2883, -0.2846, 2.0720]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.1675, -0.0376, 0.4402]]], grad_fn=<StackBackward>)
hidden: (tensor([[[ 0.1675, -0.0376, 0.4402]]], grad_fn=<StackBackward>), tensor([[[ 0.4394, -0.1226, 1.5611]]], grad_fn=<StackBackward>))
out: tensor([[[0.3699, 0.0150, 0.1429]]], grad_fn=<StackBackward>)
hidden: (tensor([[[0.3699, 0.0150, 0.1429]]], grad_fn=<StackBackward>), tensor([[[0.8432, 0.0618, 0.9413]]], grad_fn=<StackBackward>))
out: tensor([[[0.1795, 0.0296, 0.2957]]], grad_fn=<StackBackward>)
hidden: (tensor([[[0.1795, 0.0296, 0.2957]]], grad_fn=<StackBackward>), tensor([[[0.4541, 0.1121, 0.9320]]], grad_fn=<StackBackward>))
out: tensor([[[0.1365, 0.0596, 0.3931]]], grad_fn=<StackBackward>)
hidden: (tensor([[[0.1365, 0.0596, 0.3931]]], grad_fn=<StackBackward>), tensor([[[0.3430, 0.1948, 1.0255]]], grad_fn=<StackBackward>))
Multi-Layered LSTM
import torch
import torch.nn as nn

torch.manual_seed(1)
num_layers = 2
inputs = [torch.randn(1, 3) for _ in range(5)]
hidden = (torch.randn(2, 1, 3),
          torch.randn(2, 1, 3))  # (num_layers * num_directions, batch, hidden_size) = (2 * 1, 1, 3)
lstm = nn.LSTM(input_size=3, hidden_size=3, num_layers=2)
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, `hidden` contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print('out:', out)
    print('hidden:', hidden)
Output
out: tensor([[[-0.0819, 0.1214, -0.2586]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.2625, 0.4415, -0.4917]],
[[-0.0819, 0.1214, -0.2586]]], grad_fn=<StackBackward>), tensor([[[-2.5740, 0.7832, -0.9211]],
[[-0.2803, 0.5175, -0.5330]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1298, 0.2797, -0.0882]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.3818, 0.3306, -0.3020]],
[[-0.1298, 0.2797, -0.0882]]], grad_fn=<StackBackward>), tensor([[[-2.3980, 0.6347, -0.6592]],
[[-0.3643, 0.9301, -0.1326]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1630, 0.3187, 0.0728]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.5612, 0.3134, -0.0782]],
[[-0.1630, 0.3187, 0.0728]]], grad_fn=<StackBackward>), tensor([[[-1.7555, 0.6882, -0.3575]],
[[-0.4571, 1.2094, 0.1061]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1723, 0.3274, 0.1546]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.5112, 0.1597, -0.0901]],
[[-0.1723, 0.3274, 0.1546]]], grad_fn=<StackBackward>), tensor([[[-1.4417, 0.5892, -0.2489]],
[[-0.4940, 1.3620, 0.2255]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1847, 0.2968, 0.1333]]], grad_fn=<StackBackward>)
hidden: (tensor([[[-0.3256, 0.3217, -0.1899]],
[[-0.1847, 0.2968, 0.1333]]], grad_fn=<StackBackward>), tensor([[[-1.7925, 0.6096, -0.4432]],
[[-0.5147, 1.4031, 0.2014]]], grad_fn=<StackBackward>))
Bi-Directional Multi-Layered LSTM
import torch
import torch.nn as nn

torch.manual_seed(1)
num_layers = 2
is_bidirectional = True
inputs = [torch.randn(1, 3) for _ in range(5)]
hidden = (torch.randn(4, 1, 3),
          torch.randn(4, 1, 3))  # 4 -> (2 * 2) -> num_layers * num_directions
lstm = nn.LSTM(input_size=3, hidden_size=3, num_layers=2, bidirectional=is_bidirectional)
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, `hidden` contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    print('out:', out)
    print('hidden:', hidden)

# output dim     -> (seq_len, batch, num_directions * hidden_size) -> (5, 1, 2*3)
# hidden dim     -> (num_layers * num_directions, batch, hidden_size) -> (2 * 2, 1, 3)
# cell state dim -> (num_layers * num_directions, batch, hidden_size) -> (2 * 2, 1, 3)
Output
out: tensor([[[-0.4620, 0.1115, -0.1087, 0.1646, 0.0173, -0.2196]]],
grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.5187, 0.2656, -0.2543]],
[[ 0.4175, 0.0539, 0.0633]],
[[-0.4620, 0.1115, -0.1087]],
[[ 0.1646, 0.0173, -0.2196]]], grad_fn=<StackBackward>), tensor([[[ 1.1546, 0.4012, -0.4119]],
[[ 0.7999, 0.2632, 0.2587]],
[[-1.4196, 0.2075, -0.3148]],
[[ 0.6605, 0.0243, -0.5783]]], grad_fn=<StackBackward>))
out: tensor([[[-0.1860, 0.1359, -0.2719, 0.0815, 0.0061, -0.0980]]],
grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.2945, 0.0842, -0.1580]],
[[ 0.2766, -0.1873, 0.2416]],
[[-0.1860, 0.1359, -0.2719]],
[[ 0.0815, 0.0061, -0.0980]]], grad_fn=<StackBackward>), tensor([[[ 0.5453, 0.1281, -0.2497]],
[[ 0.9706, -0.3592, 0.4834]],
[[-0.3706, 0.2681, -0.6189]],
[[ 0.2029, 0.0121, -0.3028]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.1095, 0.1520, -0.3238, 0.0283, 0.0387, -0.0820]]],
grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.1427, 0.0859, -0.2926]],
[[ 0.1536, -0.2343, 0.0727]],
[[ 0.1095, 0.1520, -0.3238]],
[[ 0.0283, 0.0387, -0.0820]]], grad_fn=<StackBackward>), tensor([[[ 0.2386, 0.1646, -0.4102]],
[[ 0.2636, -0.4828, 0.1889]],
[[ 0.1967, 0.2848, -0.7155]],
[[ 0.0735, 0.0702, -0.2859]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.2346, 0.1576, -0.4006, -0.0053, 0.0256, -0.0653]]],
grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.1706, 0.0147, -0.0341]],
[[ 0.1835, -0.3951, 0.2506]],
[[ 0.2346, 0.1576, -0.4006]],
[[-0.0053, 0.0256, -0.0653]]], grad_fn=<StackBackward>), tensor([[[ 0.3422, 0.0269, -0.0475]],
[[ 0.4235, -0.9144, 0.5655]],
[[ 0.4589, 0.2807, -0.8332]],
[[-0.0133, 0.0507, -0.1996]]], grad_fn=<StackBackward>))
out: tensor([[[ 0.2774, 0.1639, -0.4460, -0.0228, 0.0086, -0.0369]]],
grad_fn=<CatBackward>)
hidden: (tensor([[[ 0.2147, -0.0191, 0.0677]],
[[ 0.2516, -0.4591, 0.3327]],
[[ 0.2774, 0.1639, -0.4460]],
[[-0.0228, 0.0086, -0.0369]]], grad_fn=<StackBackward>), tensor([[[ 0.4414, -0.0299, 0.0889]],
[[ 0.6360, -1.2360, 0.7229]],
[[ 0.5692, 0.2843, -0.9375]],
[[-0.0569, 0.0177, -0.1039]]], grad_fn=<StackBackward>))
Upvotes: 6
Reputation: 2073
I just verified some of this using code, and it's indeed correct that for a depth-1 LSTM, h_n is the same as the last value of output (this will not be true for an LSTM of depth > 1, though, as explained above by @nnnmmm).
So, basically, the output we get after applying the LSTM is not the o_t defined in the documentation; rather, it is h_t.
import torch
import torch.nn as nn
torch.manual_seed(0)
model = nn.LSTM( input_size = 1, hidden_size = 50, num_layers = 1 )
x = torch.rand( 50, 1, 1)
output, (hn, cn) = model(x)
Now one can check that output[-1] and hn both have the same value, as follows:
tensor([[ 0.1140, -0.0600, -0.0540, 0.1492, -0.0339, -0.0150, -0.0486, 0.0188,
0.0504, 0.0595, -0.0176, -0.0035, 0.0384, -0.0274, 0.1076, 0.0843,
-0.0443, 0.0218, -0.0093, 0.0002, 0.1335, 0.0926, 0.0101, -0.1300,
-0.1141, 0.0072, -0.0142, 0.0018, 0.0071, 0.0247, 0.0262, 0.0109,
0.0374, 0.0366, 0.0017, 0.0466, 0.0063, 0.0295, 0.0536, 0.0339,
0.0528, -0.0305, 0.0243, -0.0324, 0.0045, -0.1108, -0.0041, -0.1043,
-0.0141, -0.1222]], grad_fn=<SelectBackward>)
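Continuing from the snippet above, a direct check of that claim (my own addition):

print(output.shape, hn.shape)          # torch.Size([50, 1, 50]) torch.Size([1, 1, 50])
print(torch.equal(output[-1], hn[0]))  # True when num_layers == 1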
Upvotes: 6
Reputation: 46291
It really depends on the model you use and how you interpret it. The output is almost never interpreted directly; if the input is encoded, there should be a softmax layer to decode the results.
Note: in language modeling, the hidden states are used to define the probability of the next word, p(w_{t+1} | w_1, ..., w_t) = softmax(W h_t + b).
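For instance, a minimal sketch of such a decoding step (my own illustration; the vocabulary and hidden sizes are made up):

import torch
import torch.nn as nn

vocab_size, hidden_size = 10000, 256               # assumed sizes
decoder = nn.Linear(hidden_size, vocab_size)       # the W and b in softmax(W h_t + b)
h_t = torch.randn(1, hidden_size)                  # hidden state at time t from the LSTM
p_next_word = torch.softmax(decoder(h_t), dim=-1)  # p(w_{t+1} | w_1, ..., w_t)
print(p_next_word.shape)                           # torch.Size([1, 10000])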
Upvotes: 4
Reputation: 5102
The output is the tensor of all the hidden states from each time step of the RNN (or LSTM), while the hidden state returned by the RNN (LSTM) is the last hidden state, from the last time step of the input sequence. You could check this by collecting all of the hidden states from each step and comparing them to the output (provided you are not using pack_padded_sequence).
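A minimal sketch of that check (my own example; the sizes are arbitrary):

import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=1)
x = torch.randn(6, 1, 4)                     # (seq_len, batch, input_size)

output, (h_n, c_n) = lstm(x)                 # whole sequence in one call

state = None                                 # zero initial state, same as the call above
collected = []
for t in range(x.size(0)):                   # feed the sequence one time step at a time
    out_t, state = lstm(x[t:t + 1], state)
    collected.append(out_t)

print(torch.allclose(output, torch.cat(collected, dim=0)))  # True: output == all per-step hidden states
print(torch.allclose(output[-1], state[0][-1]))             # True: the last one == h_n (last layer)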
Upvotes: 3
Reputation: 8744
I made a diagram. The names follow the PyTorch docs, although I renamed num_layers to w.
output comprises all the hidden states in the last layer ("last" depth-wise, not time-wise). (h_n, c_n) comprises the hidden states after the last timestep, t = n, so you could potentially feed them into another LSTM.
The batch dimension is not included.
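A quick numeric check of those two claims (my own sketch, with arbitrary sizes):

import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=5, hidden_size=7, num_layers=2)   # depth w = 2
x = torch.randn(10, 1, 5)                 # (seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)

print(output.shape)                       # torch.Size([10, 1, 7]) -- every t, last layer only
print(h_n.shape)                          # torch.Size([2, 1, 7])  -- last t, every layer
print(torch.equal(output[-1], h_n[-1]))   # True: the two overlap exactly here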
Upvotes: 252