Reputation: 2312
I am using the TensorFlow GRU cell to implement an RNN. I apply it to videos that run for at most 5 minutes. Since the next state is fed into the GRU automatically, how can I manually reset the state of the RNN after each epoch? In other words, I want the initial state at the beginning of training to always be 0. Here is a snippet of my code:
with tf.variable_scope('GRU'):
    latent_var = tf.reshape(latent_var, shape=[batch_size, time_steps, latent_dim])
    cell = tf.nn.rnn_cell.GRUCell(cell_size)
    H, C = tf.nn.dynamic_rnn(cell, latent_var, dtype=tf.float32)
    H = tf.reshape(H, [batch_size, cell_size])
....
Any help is much appreciated!
Upvotes: 1
Views: 1686
Reputation: 53758
Use the initial_state argument of tf.nn.dynamic_rnn:
initial_state: (optional) An initial state for the RNN. If cell.state_size is an integer, this must be a Tensor of appropriate type and shape [batch_size, cell.state_size]. If cell.state_size is a tuple, this should be a tuple of tensors having shapes [batch_size, s] for s in cell.state_size.
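The shape rule quoted above can be mirrored in plain Python. This is only an illustrative sketch: make_zero_state is a hypothetical helper, not part of TensorFlow, and it uses nested lists in place of tensors.

```python
def make_zero_state(batch_size, state_size):
    """Build an all-zero state following the int-or-tuple rule (hypothetical helper)."""
    if isinstance(state_size, int):
        # Single state: one [batch_size, state_size] block of zeros.
        return [[0.0] * state_size for _ in range(batch_size)]
    # Tuple of sizes: one [batch_size, s] block per entry, returned as a tuple.
    return tuple(make_zero_state(batch_size, s) for s in state_size)


# A cell with a single integer state size, as with a GRU cell.
gru_like = make_zero_state(batch_size=2, state_size=3)
# A cell whose state size is a tuple (sizes here are made up for illustration).
tuple_like = make_zero_state(batch_size=2, state_size=(3, 3))
```

In real code, cell.zero_state(batch_size, dtype) does this work, so a helper like this is only useful for seeing which shapes to expect.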
An adapted example from the documentation:
# create a GRUCell
cell = tf.nn.rnn_cell.GRUCell(cell_size)
# defining initial state
initial_state = cell.zero_state(batch_size, dtype=tf.float32)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
Also note that although initial_state is not a placeholder, you can still feed a value to it. So if you wish to preserve the state within an epoch but start from zero at the beginning of each epoch, you can do it like this:
# Compute the zero state array of the right shape once
zero_state = sess.run(initial_state)
# Start with a zero vector and update it
cur_state = zero_state
for batch in get_batches():
    cur_state, _ = sess.run([state, ...], feed_dict={initial_state: cur_state, ...})
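To reset at every epoch rather than only once, the loop above can be wrapped in an outer loop over epochs that reassigns the zero state. The sketch below shows just that control flow in plain Python, with no TensorFlow: toy_step is a hypothetical stand-in for the sess.run([state, ...]) call, and the weights and batches are made up.

```python
import math


def toy_step(x, state, w=0.5, u=0.8):
    """Hypothetical stand-in for sess.run([state, ...]): returns the next state."""
    return [math.tanh(w * xi + u * si) for xi, si in zip(x, state)]


def run_epochs(batches, num_epochs, state_size, reset_each_epoch=True):
    """Carry the state across batches; optionally reset it to zero per epoch."""
    zero_state = [0.0] * state_size
    cur_state = zero_state
    final_states = []
    for _ in range(num_epochs):
        if reset_each_epoch:
            cur_state = zero_state  # every epoch starts from zeros
        for batch in batches:
            # The state returned by one step is fed into the next step.
            cur_state = toy_step(batch, cur_state)
        final_states.append(cur_state)
    return final_states


batches = [[1.0, -1.0], [0.5, 0.25], [-0.3, 0.9]]
with_reset = run_epochs(batches, num_epochs=3, state_size=2)
without_reset = run_epochs(batches, num_epochs=3, state_size=2,
                           reset_each_epoch=False)
```

Because the toy weights are fixed, every epoch in with_reset ends in the same state, while without_reset drifts from epoch to epoch. In real training the weights change between steps, so the reset only guarantees that each epoch starts from the same zero state.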
Upvotes: 1