Lenar Hoyt

Reputation: 6159

How can I pass the previous state of a tuple-based tf.nn.MultiRNNCell to the next sess.run() call in TensorFlow?

I am using a stack of RNNs built with tf.nn.MultiRNNCell, and I want to pass the final_state to the next graph invocation. Since tuples are not supported in the feed dictionary, is stacking the cell states into one tensor and slicing it back into a tuple at the beginning of the graph the only way to accomplish this, or is there some functionality in TensorFlow that allows this?

Upvotes: 1

Views: 1536

Answers (2)

user1506145

Reputation: 5286

I would try to store the whole state in a tensor with the following shape:

# One array holding the full stack state; axis 1 has size 2 for (c, h)
init_state = np.zeros((num_layers, 2, batch_size, state_size))

Then feed it and unpack it inside the graph:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
# Split along the layer axis (tf.unpack was renamed tf.unstack in TF >= 1.0)
l = tf.unpack(state_placeholder, axis=0)
# Rebuild the tuple-of-LSTMStateTuples structure that MultiRNNCell expects;
# l[idx][0] is the cell state c, l[idx][1] is the hidden state h
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
     for idx in range(num_layers)])
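
To close the loop on the feeding side: the final_state returned by sess.run() is a tuple of LSTMStateTuples, and since those are plain (named) tuples of equally shaped arrays, np.array() stacks them straight back into the (num_layers, 2, batch_size, state_size) layout. A minimal sketch; the names x, outputs, final_state, and batches are illustrative and assumed to be defined elsewhere in your graph and input pipeline:

import numpy as np

current_state = np.zeros((num_layers, 2, batch_size, state_size))
for batch in batches:
    out, state = sess.run(
        [outputs, final_state],
        feed_dict={x: batch, state_placeholder: current_state})
    # state is a tuple of LSTMStateTuples; np.array() restores the
    # (num_layers, 2, batch_size, state_size) array for the next step.
    current_state = np.array(state)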

Upvotes: 0

Eugene Brevdo

Reputation: 899

Suppose you have 3 RNNCells in your MultiRNNCell, and each is an LSTMCell whose state is an LSTMStateTuple. You must replicate this structure with placeholders:

lstm0_c = tf.placeholder(...)
lstm0_h = tf.placeholder(...)
lstm1_c = tf.placeholder(...)
lstm1_h = tf.placeholder(...)
lstm2_c = tf.placeholder(...)
lstm2_h = tf.placeholder(...)

# tuple() takes a single iterable, so build the tuple directly
initial_state = (
  tf.nn.rnn_cell.LSTMStateTuple(lstm0_c, lstm0_h),
  tf.nn.rnn_cell.LSTMStateTuple(lstm1_c, lstm1_h),
  tf.nn.rnn_cell.LSTMStateTuple(lstm2_c, lstm2_h))

...

sess.run(..., feed_dict={
  lstm0_c: final_state[0].c,
  lstm0_h: final_state[0].h,
  lstm1_c: final_state[1].c,
  lstm1_h: final_state[1].h,
  ...
})

If you have N stacked LSTM layers, you can create the placeholders and the feed_dict programmatically with for loops, as sketched below.
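
A minimal sketch of that loop-based generalization; the sizes and the make_feed_dict helper are illustrative, not part of the original answer:

import tensorflow as tf

num_layers, batch_size, state_size = 3, 32, 128  # illustrative sizes

# One (c, h) placeholder pair per layer, mirroring the MultiRNNCell state.
state_placeholders = tuple(
    tf.nn.rnn_cell.LSTMStateTuple(
        tf.placeholder(tf.float32, [batch_size, state_size], name='c%d' % i),
        tf.placeholder(tf.float32, [batch_size, state_size], name='h%d' % i))
    for i in range(num_layers))

def make_feed_dict(prev_state):
    # prev_state is the final_state evaluated by the previous sess.run();
    # it has the same nested structure as the placeholders, so zip them.
    feed = {}
    for ph, prev in zip(state_placeholders, prev_state):
        feed[ph.c] = prev.c
        feed[ph.h] = prev.h
    return feed

Pass state_placeholders as the initial state of tf.nn.dynamic_rnn as usual, then feed make_feed_dict(previous_final_state) on every subsequent sess.run().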

Upvotes: 4
