I. A

Reputation: 2312

How to reset the state of a GRU in tensorflow after every epoch

I am using the TensorFlow GRU cell to implement an RNN on videos that are at most 5 minutes long. Since the next state is fed into the GRU automatically, how can I manually reset the state of the RNN after each epoch? In other words, I want the initial state at the beginning of each training epoch to always be 0. Here is a snippet of my code:

with tf.variable_scope('GRU'):
    latent_var = tf.reshape(latent_var, shape=[batch_size, time_steps, latent_dim])

    cell = tf.nn.rnn_cell.GRUCell(cell_size)   
    H, C = tf.nn.dynamic_rnn(cell, latent_var, dtype=tf.float32)  
    H = tf.reshape(H, [batch_size, cell_size]) 
....

Any help is much appreciated!

Upvotes: 1

Views: 1686

Answers (1)

Maxim

Reputation: 53758

Use the initial_state argument of tf.nn.dynamic_rnn:

initial_state: (optional) An initial state for the RNN. If cell.state_size is an integer, this must be a Tensor of appropriate type and shape [batch_size, cell.state_size]. If cell.state_size is a tuple, this should be a tuple of tensors having shapes [batch_size, s] for s in cell.state_size.

An adapted example from the documentation:

# create a GRUCell
cell = tf.nn.rnn_cell.GRUCell(cell_size)

# defining initial state
initial_state = cell.zero_state(batch_size, dtype=tf.float32)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)

Also note that although initial_state is not a placeholder, you can still feed a value to it. So if you wish to preserve the state within an epoch but start from zero at the beginning of each epoch, you can do it like this:

# Compute the zero state array of the right shape once
zero_state = sess.run(initial_state)

# Start with a zero vector and update it 
cur_state = zero_state
for batch in get_batches():
  cur_state, _ = sess.run([state, ...], feed_dict={initial_state: cur_state, ...})
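
The reset-per-epoch pattern can also be sketched without TensorFlow. In the minimal example below, toy_cell is a hypothetical stand-in for the sess.run([state, ...], feed_dict=...) call and is not a real GRU; run_epochs, first_states and final_states are likewise names made up for illustration. The only point is where the zero state is re-assigned: once per epoch, outside the batch loop.

```python
import math


def toy_cell(state, x):
    """Hypothetical stand-in for one sess.run step: returns the next state.

    This is NOT a GRU, just a simple recurrent update used for illustration.
    """
    return [math.tanh(0.5 * s + x) for s in state]


def run_epochs(batches, num_epochs, state_size):
    # Compute the zero state once, like sess.run(initial_state) would.
    zero_state = [0.0] * state_size
    first_states = []  # state fed into the first batch of each epoch
    final_states = []  # state left over after the last batch of each epoch
    for _ in range(num_epochs):
        # Reset at the start of every epoch.
        cur_state = list(zero_state)
        first_states.append(list(cur_state))
        for x in batches:
            # Within an epoch, carry the state from batch to batch.
            cur_state = toy_cell(cur_state, x)
        final_states.append(cur_state)
    return first_states, final_states


first_states, final_states = run_epochs([0.1, -0.3, 0.7],
                                        num_epochs=3, state_size=4)
```

Because every epoch starts from the same zero state and sees the same batches, the final state is identical across epochs here; in real training it would differ as the weights change.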

Upvotes: 1
