Reputation: 1241
I have code like this:
import tensorflow as tf

lstm_cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple = True)

# Placeholders for the initial cell state and hidden state (batch size 1)
c_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.c], "c_in")
h_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.h], "h_in")
rnn_state_in = tf.contrib.rnn.LSTMStateTuple(c_in, h_in)

# Add a leading batch dimension so the batch is treated as the time axis:
# rnn_in has shape [1, N, features]
rnn_in = tf.expand_dims(previous_layer, [0])
sequence_length = #size of my batch

lstm_outputs, lstm_state = tf.nn.dynamic_rnn(lstm_cell,
                                             rnn_in,
                                             initial_state = rnn_state_in,
                                             sequence_length = sequence_length,
                                             time_major = False)
lstm_c, lstm_h = lstm_state
rnn_out = tf.reshape(lstm_outputs, [-1, 256])
Here, I use dynamic_rnn to simulate the time steps from the batch. On each forward pass, I can get lstm_c and lstm_h, which I can store anywhere outside the graph.
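Roughly, each forward pass looks like this (sess, the initial value of prev_rnn_state, and the feeds for the rest of my network are omitted here):

# Run one forward pass and pull out the final LSTMStateTuple so it can be
# stored outside the graph.
feed = {c_in: prev_rnn_state[0],    # cell state saved from the previous batch
        h_in: prev_rnn_state[1]}    # hidden state saved from the previous batch
# ... plus whatever feeds the rest of my network needs
rnn_out_val, (c_val, h_val) = sess.run([rnn_out, lstm_state], feed_dict = feed)
prev_rnn_state = (c_val, h_val)     # kept for the next forward pass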
So, suppose I have done a forward pass for N items of a sequence through my network and have the final cell state and hidden state returned by dynamic_rnn. Now, to perform backpropagation, what should my input to the LSTM be?
By default, does backprop happen across time steps in dynamic_rnn?
(say, no. of time steps = batch_size=N)
So is it enough for me to provide the input as:
sess.run(_train_op, feed_dict = {_state: np.vstack(batch_states),
                                 ...
                                 c_in: prev_rnn_state[0],
                                 h_in: prev_rnn_state[1]
                                 })
(where prev_rnn_state is a tuple of (cell state, hidden state) that I got from dynamic_rnn during forward propagation for the previous batch.)
Or do I have to unroll the LSTM layer across the time steps explicitly and train it by providing the cell states and hidden states gathered across the previous time steps?
Upvotes: 0
Views: 245
Reputation: 1058
Yes, backprop happens across time steps in dynamic_rnn.
But I think you should look at the inputs parameter of dynamic_rnn: it should be of shape [batch_size, max_time, ...]. When you call dynamic_rnn with an input of that shape, it calls your lstm_cell max_time times, starting from the initial state you provide as rnn_state_in.
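For illustration, here is a minimal sketch of that shape contract (the feature size and placeholder names are made up, not taken from your code):

import tensorflow as tf

feature_dim = 32                                    # hypothetical feature size
cell = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple = True)

# inputs must be [batch_size, max_time, features] when time_major = False
inputs = tf.placeholder(tf.float32, [1, None, feature_dim], "inputs")
c_in = tf.placeholder(tf.float32, [1, cell.state_size.c], "c_in")
h_in = tf.placeholder(tf.float32, [1, cell.state_size.h], "h_in")

# dynamic_rnn applies `cell` max_time times, threading (c, h) from one step
# to the next internally, starting from the initial_state you feed in.
outputs, state = tf.nn.dynamic_rnn(cell,
                                   inputs,
                                   initial_state = tf.contrib.rnn.LSTMStateTuple(c_in, h_in),
                                   time_major = False)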
Remember, at each time step dynamic_rnn takes the c and h states from the previous time step automatically, so you don't have to feed them for every step inside sess.run(...). You only need to feed the inputs.
And backprop across all time steps is calculated when you compute a loss using the final state (or all outputs) of your LSTM and minimize it with an optimizer like SGD or Adam.
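Continuing the sketch above, a training step could look roughly like this (the target placeholder, the loss, and the batches iterable are invented for illustration):

import numpy as np

targets = tf.placeholder(tf.float32, [1, None, 256], "targets")   # invented target

# Any loss built from the outputs (or the final state) depends on every cell
# application, so the optimizer backpropagates through all max_time steps.
loss = tf.reduce_mean(tf.square(outputs - targets))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    prev_c = np.zeros((1, 256), np.float32)
    prev_h = np.zeros((1, 256), np.float32)
    for batch_x, batch_y in batches:          # each batch_x: [1, N, feature_dim]
        # Feed the saved final state once per batch; the per-step (c, h)
        # handoff across the N steps happens inside dynamic_rnn itself.
        # (Gradients do not flow across batches through these numpy arrays,
        # i.e. this is truncated backprop through time.)
        _, state_val = sess.run([train_op, state],
                                feed_dict = {inputs: batch_x, targets: batch_y,
                                             c_in: prev_c, h_in: prev_h})
        prev_c, prev_h = state_val.c, state_val.h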
Upvotes: 1