Gokul NC

Reputation: 1241

Backpropagation through time for tf.nn.dynamic_rnn for sequential input (from batch)

I have code like this:

lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple = True)
c_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.c], "c_in")
h_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.h], "h_in")
rnn_state_in = tf.contrib.rnn.LSTMStateTuple(c_in, h_in)
rnn_in = tf.expand_dims(previous_layer, [0])  # add a batch dim: [N, features] -> [1, N, features]
sequence_length = ...  # size of my batch
lstm_outputs, lstm_state = tf.nn.dynamic_rnn(lstm_cell,
                                             rnn_in,
                                             initial_state = rnn_state_in,
                                             sequence_length = sequence_length,
                                             time_major = False)
lstm_c, lstm_h = lstm_state
rnn_out = tf.reshape(lstm_outputs, [-1, 256])

Here, I use dynamic_rnn to simulate the time steps from the batch. On each forward pass I get lstm_c and lstm_h, which I can store anywhere outside the graph.
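Concretely, each forward pass looks roughly like this (a sketch that continues the code above; _state is my network's input placeholder, batch_states holds the N inputs of the batch, and prev_rnn_state starts out as a pair of zero arrays of shape [1, 256]):

out_vals, prev_rnn_state = sess.run(
    [rnn_out, lstm_state],
    feed_dict = {_state: np.vstack(batch_states),  # the N inputs of this batch
                 c_in: prev_rnn_state[0],          # cell state carried over from
                 h_in: prev_rnn_state[1]})         # the previous batch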

So, suppose I have done a forward pass for N items of a sequence through my network and have the final cell state and hidden state returned by dynamic_rnn. Now, to perform backpropagation, what should my input to the LSTM be?

By default, does backprop happen across time steps in dynamic_rnn?

(say, no. of time steps = batch_size=N)

So is it enough for me to provide the input as:

sess.run(_train_op, feed_dict = {_state: np.vstack(batch_states),
                                            ...
                                            c_in: prev_rnn_state[0],
                                            h_in: prev_rnn_state[1]
})

(where prev_rnn_state is the (cell state, hidden state) tuple that dynamic_rnn returned from the forward pass over the previous batch.)

Or do I have to unroll the LSTM layer across the time steps explicitly and train it by providing a vector of the cell states and hidden states gathered across the previous time steps?

Upvotes: 0

Views: 245

Answers (1)

amin__

Reputation: 1058

Yes, backprop happens across time steps in dynamic_rnn.

But I think you should study the inputs parameter of dynamic_rnn: it must have shape [batch_size, max_time, ...]. When you call dynamic_rnn with an input of that shape, it applies your lstm_cell max_time times, starting from the initial state you provide as rnn_state_in.
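For example, here is a minimal sketch with made-up sizes (TF 1.x API), where one sequence of n_steps items is fed as a single batch element:

import tensorflow as tf

n_steps, n_features, n_units = 20, 10, 256   # made-up sizes

cell = tf.contrib.rnn.BasicLSTMCell(n_units, state_is_tuple = True)

# inputs must be [batch_size, max_time, depth] when time_major = False;
# here batch_size = 1 and max_time = n_steps (your N items).
x = tf.placeholder(tf.float32, [1, n_steps, n_features])
c_in = tf.placeholder(tf.float32, [1, cell.state_size.c])
h_in = tf.placeholder(tf.float32, [1, cell.state_size.h])
init_state = tf.contrib.rnn.LSTMStateTuple(c_in, h_in)

# dynamic_rnn applies `cell` n_steps times, threading c and h internally.
outputs, final_state = tf.nn.dynamic_rnn(cell, x,
                                         initial_state = init_state,
                                         time_major = False)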

Remember, at each time step dynamic_rnn takes the c and h states from the previous time step automatically, so you don't have to feed them for every step inside sess.run(..). You only need to feed the inputs (and the initial state once per batch).

And backprop across all time steps is computed when you define a loss using the final state (or all of the outputs) of your LSTM and minimize it with an optimizer like SGD or Adam.
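Continuing the sketch above, something like this would do (the targets placeholder and the learning rate are just illustrative; a single sess.run then runs the forward pass over all n_steps and backpropagates through all of them):

# A loss over all time-step outputs; gradients flow back through the
# whole unrolled sequence when the optimizer minimizes it.
targets = tf.placeholder(tf.float32, [1, n_steps, n_units])
loss = tf.reduce_mean(tf.square(outputs - targets))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

# Per batch, feed only the inputs and the initial state, e.g.:
# sess.run(train_op, feed_dict = {x: batch_x, targets: batch_y,
#                                 c_in: prev_c, h_in: prev_h})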

Upvotes: 1
