LJKS

Reputation: 919

Output of Tensorflow LSTM-Cell

I have a question about TensorFlow's LSTM implementation. There are currently several LSTM implementations in TF, but I use:

cell = tf.contrib.rnn.BasicLSTMCell(n_units)

Then to get my output I call:

rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, x,
                                            initial_state=initial_state, time_major=False)

I expected rnn_outputs to have shape (batch_size, time_steps, n_units, input_length), since I have not specified a different output size. However, the documentation of tf.nn.dynamic_rnn says the output has shape (batch_size, max_time, cell.output_size). The documentation of tf.contrib.rnn.BasicLSTMCell lists a property output_size, which defaults to n_units (the number of LSTM cells I use).

So does each LSTM cell output only a scalar at every time step? I would expect it to output a vector with the same length as the input vector. From how I currently understand it, that does not seem to be the case, so I am confused. Can you tell me whether that is right, or how I could change it so that each LSTM cell outputs a vector the size of the input vector?

Upvotes: 7

Views: 4093

Answers (1)

Animesh Karnewar

Reputation: 436

I think the main confusion is the terminology of the LSTM cell's argument num_units. Unfortunately, it does not mean, as the name suggests, "the number of LSTM cells" (which you might expect to equal your number of time steps). It actually sets the dimensionality of the hidden state: the cell state c and the hidden state h are each vectors of num_units values. So at every time step the cell emits a vector of num_units values per example, not a scalar. The outputs tensor returned by dynamic_rnn() (the first element of the (outputs, state) pair) has shape [batch_size, time_steps, output_size], where:

(Please note this) output_size = num_units if num_proj is None in the LSTM cell,
whereas output_size = num_proj if it is set.
(num_proj is an argument of tf.contrib.rnn.LSTMCell; BasicLSTMCell does not accept it.)
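To make the shapes concrete, here is a minimal NumPy sketch (not TensorFlow's actual code) of an LSTM loop with an optional projection, loosely modeled on how LSTMCell feeds its (possibly projected) output back in at the next step. The function name lstm_forward, the random weights, the gate order and the omitted biases are illustrative assumptions; only the shapes matter here. It shows that the cell state holds num_units values per example, that every time step emits a vector rather than a scalar, and that setting num_proj changes the length of each emitted vector.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_forward(x, num_units, num_proj=None, seed=0):
    """Toy LSTM over x of shape (batch, time, input_dim); random weights, shapes only."""
    batch, time_steps, input_dim = x.shape
    output_size = num_units if num_proj is None else num_proj  # the rule above
    rng = np.random.default_rng(seed)
    # One kernel maps [x_t, m_{t-1}] to all four gates at once (biases omitted).
    kernel = rng.normal(scale=0.1, size=(input_dim + output_size, 4 * num_units))
    proj = None if num_proj is None else rng.normal(scale=0.1, size=(num_units, num_proj))

    c = np.zeros((batch, num_units))    # cell state: num_units values per example
    m = np.zeros((batch, output_size))  # emitted output, fed back at the next step
    outputs = []
    for t in range(time_steps):
        z = np.concatenate([x[:, t, :], m], axis=1) @ kernel
        i, j, f, o = np.split(z, 4, axis=1)  # input gate, candidate, forget gate, output gate
        c = sigmoid(f + 1.0) * c + sigmoid(i) * np.tanh(j)  # +1.0 acts like a forget bias
        m = sigmoid(o) * np.tanh(c)         # hidden state h_t: (batch, num_units)
        if proj is not None:
            m = m @ proj                    # projected to (batch, num_proj) at every step
        outputs.append(m)
    return np.stack(outputs, axis=1), c

x = np.ones((2, 5, 3))  # batch_size=2, time_steps=5, input_dim=3
out, c = lstm_forward(x, num_units=4)
print(out.shape, c.shape)      # (2, 5, 4) (2, 4)
out_p, c_p = lstm_forward(x, num_units=4, num_proj=3)
print(out_p.shape, c_p.shape)  # (2, 5, 3) (2, 4)
```

With input_dim=3 and num_proj=3, each time step emits a vector the same length as the input vector, which is what the question asks for, while the internal cell state still has num_units values.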

Typically, you then either take the last time step's output and project it to the desired output dimension yourself with a mat-mul plus bias, or pass num_proj to the LSTM cell, which applies a projection at every time step.
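A minimal NumPy sketch of the manual mat-mul + bias route, assuming rnn_outputs has the (batch_size, time_steps, num_units) layout that dynamic_rnn produces with time_major=False. W and b are hypothetical stand-ins for trained variables, and random numbers stand in for real LSTM outputs. In TF 1.x, option 1 would be written as tf.matmul(rnn_outputs[:, -1, :], W) + b.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, time_steps, num_units, output_dim = 2, 5, 4, 3

# Stand-in for rnn_outputs from dynamic_rnn: (batch_size, time_steps, num_units)
rnn_outputs = rng.normal(size=(batch_size, time_steps, num_units))

# Hypothetical dense-layer parameters; in a real model these are trained variables.
W = rng.normal(scale=0.1, size=(num_units, output_dim))
b = np.zeros(output_dim)

# Option 1: keep only the last time step, then mat-mul + bias.
last = rnn_outputs[:, -1, :]      # (batch_size, num_units)
prediction = last @ W + b         # (batch_size, output_dim)

# Option 2: project every time step with the same weights.
per_step = rnn_outputs @ W + b    # (batch_size, time_steps, output_dim)

print(prediction.shape, per_step.shape)  # (2, 3) (2, 5, 3)
```

If your sequences are padded and you pass sequence_length to dynamic_rnn, rnn_outputs[:, -1, :] is zero for the shorter sequences; in that case the h part of the returned final state (rnn_states.h for BasicLSTMCell) holds the last valid output of each sequence.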
I went through the same confusion myself and had to dig quite deep to clear it up. I hope this answer clears some of it up for you.

Upvotes: 6
