I am trying to find the best way to pass the LSTM state between batches. I have searched extensively but could not find a solution that works with the current API. Imagine I have something like:
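import tensorflow as tf
from tensorflow.contrib import rnn  # TF 1.x only
# x_hot: the one-hot encoded input batch, shape [batch, time, depth]
cells = [rnn.LSTMCell(size) for size in [256, 256]]
cells = rnn.MultiRNNCell(cells, state_is_tuple=True)
init_state = cells.zero_state(tf.shape(x_hot)[0], dtype=tf.float32)
net, new_state = tf.nn.dynamic_rnn(cells, x_hot, initial_state=init_state, dtype=tf.float32)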
Now I would like to pass new_state on to the next batch efficiently, i.e. without storing it back to memory and then re-feeding it to TensorFlow through feed_dict. To be more precise, all the solutions I found use sess.run to evaluate new_state and feed_dict to pass it into init_state. Is there any way to do this without the feed_dict bottleneck?
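For contrast, the feed_dict round trip described above can be sketched as follows. This is a minimal sketch assuming graph-mode TensorFlow accessed through tf.compat.v1 (so it should also run on a TF 2.x release that still ships the TF1 RNN cells); the toy sizes are hypothetical. It relies on the fact that TF1 lets you feed any tensor, including the zero_state tensors.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()
tf1.reset_default_graph()  # start from a clean graph

BATCH, TIME, DEPTH = 4, 5, 8  # hypothetical toy sizes
SIZES = [16, 16]              # the question uses [256, 256]

x_hot = tf1.placeholder(tf.float32, [None, TIME, DEPTH])
cells = tf1.nn.rnn_cell.MultiRNNCell(
    [tf1.nn.rnn_cell.LSTMCell(size) for size in SIZES], state_is_tuple=True)
init_state = cells.zero_state(tf.shape(x_hot)[0], dtype=tf.float32)
net, new_state = tf1.nn.dynamic_rnn(cells, x_hot, initial_state=init_state)

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    batch = np.random.default_rng(0).random((BATCH, TIME, DEPTH), dtype=np.float32)
    # First batch: init_state is not fed, so the zero state is used.
    out1, state = sess.run([net, new_state], feed_dict={x_hot: batch})
    # `state` now holds NumPy arrays in host memory (one LSTMStateTuple per layer).
    # Feeding them back copies them into the graph again on every step.
    feed = {x_hot: batch}
    for layer_init, layer_state in zip(init_state, state):
        feed[layer_init.c] = layer_state.c
        feed[layer_init.h] = layer_state.h
    out2, state = sess.run([net, new_state], feed_dict=feed)
```

Every step fetches the full state out of the session and sends it back in, which is the overhead the question wants to avoid.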
I think I should use tf.assign in some way, but the documentation is incomplete and I could not find a workaround.
Thanks in advance to everybody who answers.
Cheers,
Francesco Saverio
All the other answers I found on Stack Overflow either work only with older versions or use feed_dict to pass the new state. For instance:
1) TensorFlow: Remember LSTM state for next batch (stateful LSTM): this works by using feed_dict to feed the state placeholder, and I want to avoid that.
2) Tensorflow - LSTM state reuse within batch: this does not work with the state tuple.
3) Saving LSTM RNN state between runs in Tensorflow: same here.
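LSTMStateTuple is nothing more than a tuple (c, h); in the TensorFlow docs' terminology, c is the hidden state and h is the output. tf.assign creates an operation that, when run, assigns the value of a tensor to a variable (if you have specific questions, please ask so that the docs can be improved). You can use the tf.assign solution by retrieving the hidden state tensor from the tuple through its c attribute (assuming you want the hidden state): new_state.c. Since your cell is a MultiRNNCell, new_state is a tuple with one LSTMStateTuple per layer, so for layer i this is new_state[i].c (and new_state[i].h for the output; carrying the full LSTM state needs both).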
Here is a complete self-contained example on a toy problem: https://gist.github.com/iganichev/632b425fed0263d0274ec5b922aa3b2f
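Along the same lines, here is a minimal sketch of the tf.assign approach for the MultiRNNCell setup from the question. It is not the code from the linked gist. It assumes graph-mode TensorFlow through tf.compat.v1 and a fixed batch size, since the carried state lives in variables of fixed shape; the sizes and the names state_vars, update_ops and reset_state are hypothetical. Each layer's (c, h) is kept in non-trainable variables, read as the initial state, and overwritten with new_state in the same sess.run via control dependencies.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()
tf1.reset_default_graph()  # start from a clean graph

BATCH, TIME, DEPTH = 4, 5, 8  # hypothetical toy sizes
SIZES = [16, 16]              # the question uses [256, 256]

# The batch size is fixed because the carried state lives in fixed-shape variables.
x_hot = tf1.placeholder(tf.float32, [BATCH, TIME, DEPTH])
cells = tf1.nn.rnn_cell.MultiRNNCell(
    [tf1.nn.rnn_cell.LSTMCell(size) for size in SIZES], state_is_tuple=True)

# One non-trainable (c, h) variable pair per layer holds the state between runs.
state_vars = [
    (tf1.Variable(tf.zeros([BATCH, size]), trainable=False),
     tf1.Variable(tf.zeros([BATCH, size]), trainable=False))
    for size in SIZES
]

# Read the stored state and use it as the initial state of this run.
init_state = tuple(
    tf1.nn.rnn_cell.LSTMStateTuple(tf.identity(c), tf.identity(h))
    for c, h in state_vars)
net, new_state = tf1.nn.dynamic_rnn(cells, x_hot, initial_state=init_state)

# Write the final state of this run back into the variables (one pair per layer).
update_ops = []
for (c_var, h_var), layer_state in zip(state_vars, new_state):
    update_ops.append(tf1.assign(c_var, layer_state.c))
    update_ops.append(tf1.assign(h_var, layer_state.h))

# Fetching `net` now also runs the assignments, all inside the graph.
with tf.control_dependencies(update_ops):
    net = tf.identity(net)

# Re-initializing the state variables resets the carried state to zeros.
reset_state = tf1.variables_initializer([v for pair in state_vars for v in pair])

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    batch = np.random.default_rng(0).random((BATCH, TIME, DEPTH), dtype=np.float32)
    first = sess.run(net, feed_dict={x_hot: batch})   # starts from the zero state
    second = sess.run(net, feed_dict={x_hot: batch})  # starts from the carried state
    sess.run(reset_state)                             # e.g. at an epoch boundary
    third = sess.run(net, feed_dict={x_hot: batch})   # zero state again
    carried_c = sess.run(state_vars[0][0])            # layer-0 c after the last run
```

Only the input batch is fed; the state never leaves the graph. Because the variables have a fixed shape, this sketch assumes every batch has exactly BATCH rows (unlike tf.shape(x_hot)[0] in the question), and carrying the state only makes sense if row i of each batch continues row i of the previous one.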