Martin Studer

Reputation: 2321

Structure of initial state for stacked LSTM

What is the required structure for an initial state on a multilayer/stacked RNN in TensorFlow (1.13.1) using the tf.keras.layers.RNN API?

I tried the following:

lstm_cell_sizes = [256, 256, 256]
lstm_cells = [tf.keras.layers.LSTMCell(size) for size in lstm_cell_sizes]

state_init = [tf.placeholder(tf.float32, shape=[None] + cell.state_size) for cell in lstm_cells]

tf.keras.layers.RNN(lstm_cells, ...)(inputs, initial_state=state_init)

This results in:

ValueError: Could not pack sequence. Structure had 6 elements, but flat_sequence had 3 elements.  Structure: ([256, 256], [256, 256], [256, 256]), flat_sequence: [<tf.Tensor 'player/Placeholder:0' shape=(?, 256, 256) dtype=float32>, <tf.Tensor 'player/Placeholder_1:0' shape=(?, 256, 256) dtype=float32>, <tf.Tensor 'player/Placeholder_2:0' shape=(?, 256, 256) dtype=float32>].

If I change state_init to be a flattened list of tensors with shape [None, 256] instead, I am getting:

ValueError: An `initial_state` was passed that is not compatible with `cell.state_size`. Received `state_spec`=[InputSpec(shape=(None, 256), ndim=2), InputSpec(shape=(None, 256), ndim=2), InputSpec(shape=(None, 256), ndim=2)]; however `cell.state_size` is [[256, 256], [256, 256], [256, 256]]

The TensorFlow RNN docs are fairly vague on this:

"You can specify the initial state of RNN layers symbolically by calling them with the keyword argument initial_state. The value of initial_state should be a tensor or list of tensors representing the initial state of the RNN layer."
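For reference, the second error message reports `cell.state_size` as `[[256, 256], [256, 256], [256, 256]]`, which suggests each LSTM cell expects an `[h, c]` pair of tensors rather than a single state tensor. A sketch of that nested structure (using NumPy zeros in place of the placeholders, with an arbitrary example batch size):

```python
import numpy as np

batch_size = 4  # arbitrary example batch size
lstm_cell_sizes = [256, 256, 256]

# One [h, c] pair per LSTMCell, each tensor of shape (batch_size, units)
state_init = [
    [np.zeros((batch_size, size), dtype=np.float32),  # h (hidden state)
     np.zeros((batch_size, size), dtype=np.float32)]  # c (cell state)
    for size in lstm_cell_sizes
]
```

This yields three `[h, c]` pairs, matching the six-element structure the error message says the RNN expects.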

Upvotes: 5

Views: 1398

Answers (2)

Bon Ryu

Reputation: 718

In TF2, the docs for tf.keras.layers.RNN, tf.keras.layers.LSTM, and tf.keras.layers.GRU show that these layers' call() functions all take an initial_state= parameter. Here is the description:

initial_state: List of initial state tensors to be passed to the first call of the cell (optional, defaults to None which causes creation of zero-filled initial state tensors).

Below is a modified version of the small example from the tf.keras.layers.StackedRNNCells docs:

import numpy as np
import tensorflow as tf

batch_size = 3
sentence_max_length = 5
n_hidden = 2       # input feature dimension
n_hid_layers = 2   # number of stacked LSTM layers
new_shape = (batch_size, sentence_max_length, n_hidden)
x = tf.constant(np.reshape(np.arange(30), new_shape), dtype=tf.float32)

rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(n_hid_layers)]
stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)

# Return the state to serve the dual purpose of
#  1) verifying that the returned state is a list and
#  2) re-using the returned state as the initial_state
#     in a subsequent call to lstm_layer
lstm_layer = tf.keras.layers.RNN(stacked_lstm, return_state=True, return_sequences=False)

# In the first call, initial_state=None, which leads to zero-filled initial states.
# The call to lstm_layer returns a list: result[0] is the output of the LSTM,
# and result[1] and result[2] are the [h, c] state pairs of the 1st and 2nd
# layers of the LSTM, respectively.
result = lstm_layer(x)
# Feed the input again, re-using the returned states as the initial_state.
# (Note that result[0] is 2D because return_sequences=False, so it cannot
# itself be fed back in as a sequence input.)
result2 = lstm_layer(x, initial_state=result[1:])

Here is a snapshot of what the variables look like in my PyCharm Python console:

[screenshot of the PyCharm console showing the returned variables]

Upvotes: 0

Richard

Reputation: 61489

I believe this is how you do it in TF2:

import numpy as np
import tensorflow.compat.v2 as tf #If you have a newer version of TF1
#import tensorflow as tf          #If you have TF2

sentence_max_length = 5
batch_size = 3
n_hidden = 2
x = tf.constant(np.reshape(np.arange(30), (batch_size, sentence_max_length, n_hidden)), dtype=tf.float32)

stacked_lstm = tf.keras.layers.StackedRNNCells([tf.keras.layers.LSTMCell(128) for _ in range(2)])

lstm_layer = tf.keras.layers.RNN(stacked_lstm,return_state=False,return_sequences=False)

result = lstm_layer(x)
print(result)
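If you also want to pass an explicit initial_state (the original question), my understanding is that you supply one [h, c] pair per stacked cell. A sketch, assuming zero-filled states and the same toy input as above:

```python
import numpy as np
import tensorflow as tf

batch_size = 3
stacked_lstm = tf.keras.layers.StackedRNNCells(
    [tf.keras.layers.LSTMCell(128) for _ in range(2)])
lstm_layer = tf.keras.layers.RNN(stacked_lstm)

x = tf.constant(np.reshape(np.arange(30), (batch_size, 5, 2)), dtype=tf.float32)

# One [h, c] pair per stacked LSTMCell, each of shape (batch_size, units)
init_state = [
    [tf.zeros((batch_size, 128)), tf.zeros((batch_size, 128))]
    for _ in range(2)
]
result = lstm_layer(x, initial_state=init_state)
```

With `return_sequences=False` (the default), `result` has shape `(batch_size, 128)`.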

Upvotes: 1
