lenhhoxung

Reputation: 2746

States in the TensorFlow static RNN

I'm trying to work with RNNs in TensorFlow. I'm using the following function from this repo:

import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.x)
from tensorflow.contrib import rnn

# Assumes module-level `timesteps` and `num_hidden`, as in the original example
def RNN(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, timesteps, n_input)
    # Required shape: 'timesteps' tensors list of shape (batch_size, n_input)

    # Unstack to get a list of 'timesteps' tensors of shape (batch_size, n_input)
    x = tf.unstack(x, timesteps, 1)

    # Define an LSTM cell with TensorFlow
    lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)

    # Get LSTM cell output
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

    # Linear activation, using the last output of the RNN inner loop
    return tf.matmul(outputs[-1], weights['out']) + biases['out']

I understand that outputs is a list containing the intermediate outputs of the unrolled network, and I can verify that len(outputs) equals timesteps. However, I don't understand why len(states) equals 2; I expected it to contain only the final state of the network. Could you please help explain? Thanks.
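To make the shapes concrete, here is a framework-free sketch of what a static unroll does, assuming one scalar input per timestep and a plain tanh RNN cell. The names `static_unroll` and `tanh_rnn_cell` and the weights are hypothetical, not TensorFlow APIs. The loop emits one output per timestep and returns whatever single state object the cell defines; for this simple cell the state is just h, so nothing here would have length 2.

```python
import math
from typing import Callable, List, Tuple

# Hypothetical cell signature: (input, state) -> (output, new_state)
Cell = Callable[[float, float], Tuple[float, float]]

def tanh_rnn_cell(x: float, h: float, w_x: float = 0.5, w_h: float = 0.8) -> Tuple[float, float]:
    """Hypothetical vanilla RNN cell: output and new state are the same value h."""
    h_new = math.tanh(w_x * x + w_h * h)
    return h_new, h_new

def static_unroll(cell: Cell, inputs: List[float], initial_state: float) -> Tuple[List[float], float]:
    """Apply `cell` once per timestep: collect every output, keep only the final state."""
    state = initial_state
    outputs: List[float] = []
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state

inputs = [1.0, 0.5, -0.3, 0.2]  # timesteps = 4
outputs, state = static_unroll(tanh_rnn_cell, inputs, 0.0)
print(len(outputs))          # 4: one output per timestep
print(state == outputs[-1])  # True: for this cell the final state is the last output
```

The structure of the returned state is decided entirely by the cell, which is why swapping in an LSTM cell changes what states looks like.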

Upvotes: 0

Views: 619

Answers (1)

Akshay Agrawal

Reputation: 922

To confirm the discussion in the comments: when you construct a static RNN using BasicLSTMCell (with the default state_is_tuple=True), states is an LSTMStateTuple, a two-tuple (c, h), where c is the final cell state and h is the final hidden state. That is why len(states) equals 2. The final hidden state h is in fact equal to the last element of outputs. You can corroborate this by reading the source code (see BasicLSTMCell's call method).
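A minimal pure-Python sketch of a basic LSTM step for a single hidden unit may help, assuming scalar inputs and made-up weights. `lstm_step`, `run`, `LSTMState`, and the weight values are hypothetical; `LSTMState` only mirrors the (c, h) layout of TensorFlow's LSTMStateTuple, and forget_bias=1.0 matches the question. It shows why the returned state has two parts and why h equals the last output.

```python
import math
from typing import List, NamedTuple, Tuple

class LSTMState(NamedTuple):
    c: float  # cell (memory) state
    h: float  # hidden state, also emitted as the cell's output

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical weights for one hidden unit: (input weight, recurrent weight, bias) per gate.
W = {"i": (0.6, 0.1, 0.0), "j": (0.9, -0.2, 0.0), "f": (0.3, 0.4, 0.0), "o": (0.7, 0.2, 0.0)}

def lstm_step(x: float, state: LSTMState, forget_bias: float = 1.0) -> Tuple[float, LSTMState]:
    """One basic LSTM step; returns (output, new_state) where output == new_state.h."""
    def pre(gate: str) -> float:
        w_x, w_h, b = W[gate]
        return w_x * x + w_h * state.h + b
    # Forget part of the old memory, then add the gated candidate value.
    new_c = (state.c * sigmoid(pre("f") + forget_bias)
             + sigmoid(pre("i")) * math.tanh(pre("j")))
    # The output gate decides how much of the squashed memory becomes h.
    new_h = math.tanh(new_c) * sigmoid(pre("o"))
    return new_h, LSTMState(c=new_c, h=new_h)

def run(inputs: List[float]) -> Tuple[List[float], LSTMState]:
    """Unroll lstm_step over the inputs, collecting outputs and keeping the final state."""
    state = LSTMState(c=0.0, h=0.0)
    outputs: List[float] = []
    for x in inputs:
        out, state = lstm_step(x, state)
        outputs.append(out)
    return outputs, state

outputs, state = run([1.0, 0.5, -0.3, 0.2])
c, h = state             # the "length 2" state unpacks into (c, h)
print(len(state))        # 2
print(h == outputs[-1])  # True
```

Only c is extra information; h is already available as outputs[-1], which is why the question's RNN function can use the last output directly.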

Upvotes: 1
