Reputation: 6159
I am using a stack of RNNs built with tf.nn.MultiRNNCell, and I want to pass the final_state to the next invocation of the graph. Since tuples are not supported as values in the feed dictionary, is stacking the cell states into one tensor and slicing it back into a tuple at the beginning of the graph the only way to accomplish this, or is there some functionality in TensorFlow that supports this directly?
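Roughly, my setup looks like this (the sizes and names here are illustrative, not my actual values):
import tensorflow as tf

num_layers, state_size = 3, 128                 # illustrative sizes
batch_size, max_time, input_size = 32, 20, 50   # illustrative sizes

inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_size])
cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
     for _ in range(num_layers)],
    state_is_tuple=True)

# final_state is a tuple of num_layers LSTMStateTuple(c, h) pairs,
# which I cannot pass directly as a feed_dict value.
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)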
Upvotes: 1
Views: 1536
Reputation: 5286
I would try to store the whole state in a single tensor of the following shape, where the 2 covers the c and h parts of each layer's LSTMStateTuple:
init_state = np.zeros((num_layers, 2, batch_size, state_size))
Then feed it and unpack it back into a tuple inside your graph:
state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)  # renamed tf.unstack in TensorFlow >= 1.0
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
     for idx in range(num_layers)]
)
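To close the loop, the final_state returned by session.run (a tuple of LSTMStateTuples) can be stacked back into the same array layout before the next call. A minimal sketch, assuming the graph above also exposes outputs, final_state, and an inputs placeholder (those names, sess, and batches are illustrative, not from the answer):
import numpy as np

current_state = np.zeros((num_layers, 2, batch_size, state_size))
for batch in batches:  # 'batches' stands in for your input pipeline
    out, state = sess.run(
        [outputs, final_state],
        feed_dict={inputs: batch, state_placeholder: current_state})
    # Re-stack the tuple of LSTMStateTuples into (num_layers, 2, ...) form
    # so it can be fed straight back in on the next iteration.
    current_state = np.array([[s.c, s.h] for s in state])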
Upvotes: 0
Reputation: 899
Suppose you have 3 RNNCells in your MultiRNNCell and each is an LSTMCell whose state is an LSTMStateTuple. You must replicate this structure with placeholders:
lstm0_c = tf.placeholder(...)
lstm0_h = tf.placeholder(...)
lstm1_c = tf.placeholder(...)
lstm1_h = tf.placeholder(...)
lstm2_c = tf.placeholder(...)
lstm2_h = tf.placeholder(...)
# Note: a plain tuple literal; tuple(a, b, c) would raise a TypeError.
initial_state = (
    tf.nn.rnn_cell.LSTMStateTuple(lstm0_c, lstm0_h),
    tf.nn.rnn_cell.LSTMStateTuple(lstm1_c, lstm1_h),
    tf.nn.rnn_cell.LSTMStateTuple(lstm2_c, lstm2_h))
...
sess.run(..., feed_dict={
lstm0_c: final_state[0].c,
lstm0_h: final_state[0].h,
lstm1_c: final_state[1].c,
lstm1_h: final_state[1].h,
...
})
If you have N stacked LSTM layers, you can create the placeholders and the feed_dict programmatically with for loops, as sketched below.
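A minimal sketch of that loop; the sizes and the helper name are assumptions for illustration, not part of the original answer:
import tensorflow as tf

num_layers = 3                     # N, assumed
batch_size, state_size = 32, 128   # assumed shapes

# One (c, h) placeholder pair per layer, mirroring the structure above.
state_placeholders = [
    (tf.placeholder(tf.float32, [batch_size, state_size], name='c%d' % i),
     tf.placeholder(tf.float32, [batch_size, state_size], name='h%d' % i))
    for i in range(num_layers)]

initial_state = tuple(
    tf.nn.rnn_cell.LSTMStateTuple(c, h) for c, h in state_placeholders)

# Given the final_state from the previous run, build the next feed_dict.
def state_feed_dict(final_state):
    feed = {}
    for (c_ph, h_ph), layer in zip(state_placeholders, final_state):
        feed[c_ph] = layer.c
        feed[h_ph] = layer.h
    return feed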
Upvotes: 4