Gemini

Reputation: 475

How are the batches iterated in the PTB LSTM example of TensorFlow?

I am currently trying to understand the LSTM tutorial from TensorFlow and have a question about the code of https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/ptb/ptb_word_lm.py. In the function run_epoch(), these lines run one epoch, where input.epoch_size is the number of batches in the data:

for step in range(model.input.epoch_size):
  feed_dict = {}
  for i, (c, h) in enumerate(model.initial_state):
    feed_dict[c] = state[i].c
    feed_dict[h] = state[i].h

  vals = session.run(fetches, feed_dict)
  cost = vals["cost"]
  state = vals["final_state"]

  costs += cost
  iters += model.input.num_steps

  if verbose and step % (model.input.epoch_size // 10) == 10:
    print("%.3f perplexity: %.3f speed: %.0f wps" %
          (step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
           iters * model.input.batch_size / (time.time() - start_time)))

But I wonder: how does this code "tell" the LSTM model which epoch we are in? In the LSTM class, the whole dataset is loaded once in the init method, and the computation is defined generally over the data.

My second question concerns feeding c and h into the computation. Why do we do that? Does it have something to do with stateful vs. stateless LSTMs? Could I safely remove that code for a vanilla LSTM?

Thanks!

Upvotes: 2

Views: 667

Answers (2)

zhufeida008

Reputation: 1

# for i, (c, h) in enumerate(model.initial_state):
#   feed_dict[c] = state[i].c
#   feed_dict[h] = state[i].h

feed_dict[model.initial_state] = state

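The for loop initializes the state at the start of the current batch with the final cell state of the previous batch. Because feed_dict also accepts a nested tuple of tensors as a key (with a value of the same structure), the single line above does the same job.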
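To see why carrying the state over matters, here is a minimal NumPy sketch. It swaps in a toy tanh RNN for the LSTM (an LSTM carries the pair (c, h) instead of a single h, but the argument is the same); W, U and run_chunk are hypothetical names for illustration only. Processing a long sequence in chunks while passing each chunk's final state into the next reproduces the result of processing it in one go, whereas restarting every chunk from zeros does not.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, n_in = 4, 3
W = rng.normal(scale=0.5, size=(n_in, hidden))    # input weights (toy RNN, not an LSTM)
U = rng.normal(scale=0.5, size=(hidden, hidden))  # recurrent weights


def run_chunk(inputs, h0):
    """Run the toy RNN over one chunk and return its final hidden state."""
    h = h0
    for x_t in inputs:
        h = np.tanh(x_t @ W + h @ U)
    return h


sequence = rng.normal(size=(12, n_in))  # one long sequence of 12 time steps
zeros = np.zeros(hidden)

# Reference: the whole sequence in one pass.
h_full = run_chunk(sequence, zeros)

# 3 chunks of 4 steps, each starting from the previous chunk's final state
# (the role of feeding state into model.initial_state).
h = zeros
for chunk in np.split(sequence, 3):
    h = run_chunk(chunk, h)
h_carried = h

# Last chunk started from zeros instead (state not carried over).
h_reset = run_chunk(np.split(sequence, 3)[-1], zeros)

print(np.allclose(h_full, h_carried))  # True
print(np.allclose(h_full, h_reset))
```

With the state carried, the chunked run matches the full run exactly; dropping the carry-over makes each batch forget everything that came before it.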

Upvotes: 0

martianwars

Reputation: 6500

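If you look at line 348 of the same file, you will see that run_epoch() is called once per epoch. Before its loop, run_epoch() evaluates state = session.run(model.initial_state), so at the start of each epoch the LSTM state is reset to all zeros; within the epoch it is then carried from batch to batch as training proceeds. Coming to your questions: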
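A rough way to picture why the model never needs the epoch number: as far as I can tell, the tutorial's reader gets its batch index from tf.train.range_input_producer(epoch_size, shuffle=False), which yields 0, 1, ..., epoch_size - 1 and then starts over, and each session.run dequeues one index. Since run_epoch() performs exactly epoch_size runs, every epoch begins again at index 0. The sketch below imitates that with itertools.cycle; run_epoch_sketch is a hypothetical stand-in, and its state update is a placeholder, not a real LSTM step.

```python
import itertools

import numpy as np

epoch_size = 4
# Stand-in for the index queue: 0, 1, 2, 3, 0, 1, 2, 3, ...
batch_index_queue = itertools.cycle(range(epoch_size))


def run_epoch_sketch(hidden_size=3, batch_size=2):
    """Hypothetical run_epoch(): returns the batch indices seen and the final state."""
    state = np.zeros((batch_size, hidden_size))  # like session.run(model.initial_state)
    seen = []
    for step in range(epoch_size):
        i = next(batch_index_queue)  # one dequeue per session.run
        seen.append(i)
        state = state + 1.0  # placeholder for the batch's "final_state"
    return seen, state


history = [run_epoch_sketch() for _ in range(2)]
for epoch, (seen, final_state) in enumerate(history):
    print(epoch, seen)  # the epoch number is never passed to the "model"
```

Both epochs see indices [0, 1, 2, 3], and the second epoch's state starts from zeros again rather than continuing from the first.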

But I wonder: how does this code "tell" the LSTM model which epoch we are in? In the LSTM class, the whole dataset is loaded once in the init method, and the computation is defined generally over the data.

The weights inside the LSTM cell are updated as training proceeds, and at the beginning of each epoch the LSTM starts again from initial_state. There is no need to tell the LSTM the epoch number explicitly.
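The position inside the data is handled by the input pipeline (reader.ptb_producer), which slices the data by a step index. Below is a NumPy re-creation of that slicing as I read it; ptb_batches is a hypothetical name, and a Python generator replaces the TensorFlow ops. The token stream is laid out as batch_size parallel rows, step i takes columns i*num_steps to (i+1)*num_steps as inputs, and the targets are the same window shifted by one token.

```python
import numpy as np


def ptb_batches(raw_data, batch_size, num_steps):
    """Yield (x, y) batches in the order a ptb_producer-style reader would (sketch)."""
    data = np.asarray(raw_data, dtype=np.int32)
    batch_len = len(data) // batch_size
    # Lay the token stream out as batch_size parallel rows.
    data = data[: batch_size * batch_len].reshape(batch_size, batch_len)
    # -1 because each target is the next token after its input.
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):
        x = data[:, i * num_steps:(i + 1) * num_steps]
        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
        yield x, y


batches = list(ptb_batches(range(20), batch_size=2, num_steps=3))
print(len(batches))              # 3, i.e. epoch_size
print(batches[0][0].tolist())    # [[0, 1, 2], [10, 11, 12]]
print(batches[1][0].tolist())    # [[3, 4, 5], [13, 14, 15]]
```

Row r of step i + 1 picks up exactly where row r of step i stopped, which is what makes carrying the LSTM state from one batch to the next meaningful.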

My second question concerns feeding c and h into the computation. Why do we do that? Does it have something to do with stateful vs. stateless LSTMs? Could I safely remove that code for a vanilla LSTM?

This is a very important step: it passes the LSTM state across batches. An LSTM has two internal states, c (the cell state) and h (the hidden state). Feeding them in this way makes the previous batch's final state the initial state of the next batch. You can replace the loop by fetching model.final_state and passing it back in feed_dict the next time. If you look at the TensorFlow code, state is essentially a tuple of c and h when state_is_tuple is True, as you can read here.
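To make the structure that loop walks over concrete, here is a pure-Python sketch. It assumes a multi-layer cell whose state is a tuple with one (c, h) pair per layer; LSTMStateTuple below is a hypothetical namedtuple standing in for TensorFlow's class of the same name, and strings stand in for the placeholder tensors in model.initial_state.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for TensorFlow's LSTMStateTuple (fields c and h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

batch_size, hidden_size, num_layers = 2, 3, 2

# One (c, h) pair of "placeholders" per layer, like model.initial_state.
initial_state = tuple(
    LSTMStateTuple(c=f"layer{k}/c", h=f"layer{k}/h") for k in range(num_layers)
)
# Values returned as "final_state" by the previous session.run.
state = tuple(
    LSTMStateTuple(c=np.full((batch_size, hidden_size), k + 0.1),
                   h=np.full((batch_size, hidden_size), k + 0.2))
    for k in range(num_layers)
)

# The loop from run_epoch(): map every c and h placeholder to last batch's value.
feed_dict = {}
for i, (c, h) in enumerate(initial_state):
    feed_dict[c] = state[i].c
    feed_dict[h] = state[i].h

print(sorted(feed_dict))  # ['layer0/c', 'layer0/h', 'layer1/c', 'layer1/h']
```

Each layer contributes two entries, so the loop simply flattens the nested state into individual placeholder/value pairs.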

Upvotes: 4
