Reputation: 27
I have a PyTorch model:
model = torch.nn.Sequential(
    torch.nn.LSTM(40, 256, 3, batch_first=True),
    torch.nn.Linear(256, 256),
    torch.nn.ReLU()
)
For the LSTM layer, I want to retrieve only the last hidden state and pass it through the rest of the layers. For example:
_, (hidden, _) = lstm(data)
hidden = hidden[-1]
However, that example only works for a subclassed model. I need to do this in an nn.Sequential() model so that, when I save it, it can be properly converted to a tensorflow.js model. I can't build and train this model in tensorflow.js directly because I'm trying to implement this repo, Resemblyzer, in tensorflow.js while still using the weights of the pretrained Resemblyzer model, which was made in PyTorch as a subclassed model. I thought of using the torchvision.transforms.Lambda() transform, but I assume that would make it incompatible with tensorflow.js. Is there any way to make this possible while still allowing the model to convert properly?
Upvotes: 1
Views: 5380
Reputation: 1374
Though the answer is provided above, I thought I would elaborate on it, as the PyTorch LSTM documentation is confusing.
In TensorFlow (Keras), the LSTM layer returns the last state directly by default. No further action is needed.
Let us check the PyTorch output of the LSTM. There are 2 outputs: a sequence and a tuple. We are interested in the last state, so we can ignore the sequence and focus on the tuple. The tuple consists of 2 values: the first is the hidden state at the last time step (for every layer in the LSTM) and the second is the cell state at the last time step (again for every layer). We are interested in the hidden state. So
_, tup = self.bilstm(inp)
We are interested in tup[0]. Let us dig further into this.
The shape of tup[0] is somewhat odd, with the batch size in the centre, and this holds even when batch_first=True. On the left of the batch size is the number of layers in the LSTM (multiplied by 2 if it is a biLSTM). On the right is the hidden size you provided when defining the LSTM. You can take the output from the last layer by simply using tup[0][-1], which is the answer provided above.
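The shapes described above can be checked with a short sketch. The batch size, sequence length and variable names here are arbitrary choices for illustration; only the LSTM sizes come from the question.

```python
import torch
from torch import nn

# Illustrative sizes; only 40 / 256 / 3 come from the question.
batch_size, seq_len, input_size, hidden_size, num_layers = 4, 7, 40, 256, 3
inp = torch.randn(batch_size, seq_len, input_size)

# Unidirectional LSTM: tup[0] is (num_layers, batch_size, hidden_size),
# even though batch_first=True is set.
lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
_, tup = lstm(inp)
uni_shape = tuple(tup[0].shape)

# Final hidden state of the last layer only: (batch_size, hidden_size).
last_layer = tup[0][-1]

# Bidirectional LSTM: the first dimension becomes 2 * num_layers.
bilstm = nn.LSTM(input_size, hidden_size, num_layers,
                 batch_first=True, bidirectional=True)
_, bi_tup = bilstm(inp)
bi_shape = tuple(bi_tup[0].shape)
```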
Alternatively if you want to make use of hidden states across layers, you may try something like:
out = tup[0].swapaxes(0,1)
out = out.reshape(*out.shape[:-2], -1)
The first line produces a shape of (batch_size, num_layers, hidden_size). The second line produces a shape of (batch_size, num_layers x hidden_size). For a biLSTM, read num_layers as 2 x num_layers in both cases.
(For example, say yours is a biLSTM with 3 layers and a hidden size of 100. You could concatenate the output so that you get one vector of 2 x 3 x 100 = 600 dimensions, and then run a simple linear layer on top of this to get the output you want.)
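As a rough sketch of that 600-dimensional example: the snippet below builds a 3-layer biLSTM with hidden size 100, flattens the per-layer hidden states, and applies a linear layer. The input size of 40, the batch size, and the output size of 64 for the linear layer (named `proj` here) are assumptions made purely for illustration.

```python
import torch
from torch import nn

batch_size, seq_len, input_size = 4, 7, 40   # assumed for illustration
num_layers, hidden_size = 3, 100             # from the example above

bilstm = nn.LSTM(input_size, hidden_size, num_layers,
                 batch_first=True, bidirectional=True)
# Hypothetical projection; the output size 64 is an arbitrary choice.
proj = nn.Linear(2 * num_layers * hidden_size, 64)

inp = torch.randn(batch_size, seq_len, input_size)
_, tup = bilstm(inp)

# tup[0]: (2 * num_layers, batch_size, hidden_size)
out = tup[0].swapaxes(0, 1)              # (batch_size, 2 * num_layers, hidden_size)
out = out.reshape(*out.shape[:-2], -1)   # (batch_size, 2 * num_layers * hidden_size)

projected = proj(out)                    # (batch_size, 64)
```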
There is another way to get the output of the LSTM. We discussed that the first output of an LSTM is a sequence:
sequence, tup = self.bilstm(inp)
This sequence is the output of the LAST layer of the LSTM. It is a sequence because it contains the hidden state at EVERY time step of this layer, so its length is the input sequence length you provided. We can take the hidden state of the last element in the sequence like this:
# shape of sequence is: batch_size, seq_size, dim
sequence = sequence.swapaxes(0,1)
# shape of sequence is: seq_size, batch_size, dim
sequence = sequence[-1]
# shape of sequence is: batch_size, dim (i.e. the last time step is taken)
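Needless to say, this will be the same value we got by taking the last layer from tup[0]. Well, not quite! If the LSTM is a biLSTM, the sequence approach returns a 2 x hidden_size output, whereas the tup[0][-1] approach gives only hidden_size, because tup[0][-1] holds just the backward direction of the last layer (tup[0][-2] holds the forward direction). Note also that at the last time step the backward half of the sequence output has processed only one input element; the backward direction's final state sits at the first time step. OP's LSTM is not bidirectional, so both answers hold true.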
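A small check of these relationships, as a sketch with arbitrary sizes (the hidden size `H = 16` and the input shape are assumptions chosen to keep it fast):

```python
import torch
from torch import nn

torch.manual_seed(0)
H = 16                           # small hidden size, chosen for illustration
inp = torch.randn(4, 7, 40)      # (batch_size, seq_size, input_size)

with torch.no_grad():
    # Unidirectional: the last time step of the sequence equals tup[0][-1].
    lstm = nn.LSTM(40, H, 3, batch_first=True)
    sequence, (h_n, _) = lstm(inp)
    uni_match = torch.allclose(sequence[:, -1], h_n[-1], atol=1e-6)

    # Bidirectional: the sequence output has 2 * H features per time step.
    bilstm = nn.LSTM(40, H, 3, batch_first=True, bidirectional=True)
    bi_sequence, (bi_h_n, _) = bilstm(inp)
    # Forward half at the last time step is the last layer's forward state.
    fwd_match = torch.allclose(bi_sequence[:, -1, :H], bi_h_n[-2], atol=1e-6)
    # Backward half at the FIRST time step is the last layer's backward state.
    bwd_match = torch.allclose(bi_sequence[:, 0, H:], bi_h_n[-1], atol=1e-6)
```

If all three flags are True, the two ways of reading the final state agree for the unidirectional case, and the bidirectional case needs the forward and backward halves to be picked from different time steps.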
Upvotes: 0
Reputation: 40768
You could keep the nn.Sequential definition and split it up only when running the forward pass at inference. Once defined:
model = nn.Sequential(nn.LSTM(40, 256, 3, batch_first=True),
                      nn.Linear(256, 256),
                      nn.ReLU())
You can split it:
>>> lstm, fc = model[0], model[1:]
Then infer in two steps:
>>> out, (hidden, _) = lstm(data)
>>> hidden = hidden[-1]  # final hidden state of the last layer: (batch, 256)
>>> out = fc(hidden) # <- or fc(out) to run the remaining layers on every time step
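If the model has to stay a single nn.Sequential, one possible approach is to wrap the LSTM in a small module that returns only the last hidden state. This is a sketch, not a verified solution: the class name `LastHidden` is hypothetical, the input shape is arbitrary, and whether a tensorflow.js conversion pipeline accepts such a wrapper has not been checked here.

```python
import torch
from torch import nn


class LastHidden(nn.Module):
    """Hypothetical helper: runs an LSTM and returns only the final
    hidden state of its last layer."""

    def __init__(self, lstm: nn.LSTM):
        super().__init__()
        self.lstm = lstm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, (hidden, _) = self.lstm(x)
        return hidden[-1]  # (batch_size, hidden_size)


model = nn.Sequential(
    LastHidden(nn.LSTM(40, 256, 3, batch_first=True)),
    nn.Linear(256, 256),
    nn.ReLU(),
)

data = torch.randn(8, 50, 40)  # (batch_size, seq_size, features), assumed
out = model(data)              # (batch_size, 256)
```

Note that wrapping changes the parameter names in the state dict (for example `0.weight_ih_l0` becomes `0.lstm.weight_ih_l0`), so pretrained weights would need their keys remapped before loading.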
Upvotes: 2