Boris

Reputation: 886

What does nn.Linear() do in PyTorch's last layer, and why is it necessary?

I am working with some code that trains an LSTM to generate sequences. After training the model, the lstm() method is called:

x = some_input
lstm_output, (h_n, c_n) = lstm(x, hc)
# nn.Linear takes out_features (not output_features)
func = nn.Linear(in_features=lstm_num_hidden,
                 out_features=vocab_size,
                 bias=True)
func_output = func(lstm_output)

I've looked at the documentation for nn.Linear(), but I still don't understand what this transformation is doing and why it is necessary. If the LSTM has already been trained, then the output it gives should already have a pre-established dimensionality. This output (lstm_output) would be the generated sequence, or in my case an array of vectors. Am I missing something here?

Upvotes: 1

Views: 1938

Answers (1)

Wasi Ahmad

Reputation: 37691

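Here, the Linear layer transforms each hidden state representation in lstm_output (a vector of size lstm_num_hidden) into a vector of size vocab_size. So the raw LSTM output is not the generated sequence itself: its last dimension is the hidden size, not the vocabulary size. Your understanding that the Linear layer is only applied after training is perhaps wrong. It has its own weights and should be trained along with the LSTM.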
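To make the shapes concrete, here is a minimal sketch. The sizes (vocab_size = 50, lstm_num_hidden = 128), the random input, the batch_first layout, and the choice of Adam are all assumptions for illustration, not taken from the question:

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
vocab_size, lstm_num_hidden = 50, 128
batch_size, seq_len = 4, 10

lstm = nn.LSTM(input_size=vocab_size, hidden_size=lstm_num_hidden,
               batch_first=True)
func = nn.Linear(in_features=lstm_num_hidden, out_features=vocab_size,
                 bias=True)

# Random stand-in for the real input (e.g. one-hot or embedded tokens).
x = torch.randn(batch_size, seq_len, vocab_size)

# With no initial state given, the LSTM starts from zeros.
lstm_output, (h_n, c_n) = lstm(x)

# The Linear layer is applied to the last dimension at every time step.
func_output = func(lstm_output)

print(lstm_output.shape)  # torch.Size([4, 10, 128]) -> hidden size
print(func_output.shape)  # torch.Size([4, 10, 50])  -> vocabulary size

# Both modules are optimized together, so the projection is learned too.
params = list(lstm.parameters()) + list(func.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
```

The point of the sketch is only that lstm_output lives in the hidden space, and the projection to vocabulary size is a learned step of its own.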

I assume you are trying to generate a sequence of tokens (words), in which case the Linear layer should be followed by a softmax operation to produce a probability distribution over the vocabulary.
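A rough sketch of that last step, assuming the Linear layer's output has shape (batch, seq_len, vocab_size). The tensor below is a random stand-in for it, and the sizes are hypothetical:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes; logits stands in for func(lstm_output).
batch_size, seq_len, vocab_size = 4, 10, 50
logits = torch.randn(batch_size, seq_len, vocab_size)

# Softmax over the vocabulary dimension gives one distribution per time step.
probs = F.softmax(logits, dim=-1)

# Distribution for the next token, taken from the final time step.
last_step = probs[:, -1, :]

# Greedy decoding: pick the most probable token index.
greedy_token = torch.argmax(last_step, dim=-1)

# Alternatively, sample a token index from the distribution.
sampled_token = torch.multinomial(last_step, num_samples=1)
```

Note that if you train with nn.CrossEntropyLoss, you pass it the raw Linear output, since that loss applies log-softmax internally. An explicit softmax is then only needed at generation time.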

Upvotes: 5
