Evan Weissburg

Reputation: 1594

How to find and run the largest batch in a dataset before starting training

Question

In TensorFlow, I frequently run into OOM errors during the first epoch of training. However, the size of the network makes the first epoch take around an hour, far too long to test new hyper-parameters quickly.

Ideally, I'd like to be able to sort the iterator so that I can just run get_next() once on the largest batch.

How can I do this? Or perhaps there is a better way to fail early?

The iterator yields tuples of the form (source, tgt_in, tgt_out, key_weights, source_len, target_len), and I'm looking to sort by target length. The data is padded and batched before being returned.

The dataset is a list of sentences, bucketed by similar lengths. I would like to find the largest batch in the iterator and run only that batch.
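For context, the maximum source and target lengths can be found from the raw sentence list before the pipeline is built, independent of how the iterator shuffles its batches. A quick pre-scan sketch in plain Python (variable names here are illustrative, not my actual code):

def max_lengths(src_sentences, tgt_sentences):
    """Longest tokenised source and target lengths in the corpus."""
    src_max = max(len(s.split()) for s in src_sentences)
    tgt_max = max(len(t.split()) for t in tgt_sentences)
    return src_max, tgt_max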

Some Code

The code below would work if the initializer didn't shuffle the iterator every time, which destroys the information gained about the position of the largest batch. I'm not quite sure how to modify it -- as soon as the length of a batch is read using get_next(), that batch has already been "popped" and can't be used as input to the model anymore.

import itertools

import numpy as np
import tensorflow as tf


def verify_hparams():
    # First pass: scan every batch to find the index of the one with the
    # longest target sequence.
    train_sess.run(train_model.iterator.initializer)
    next_batch = train_model.iterator.get_next()  # build the op once, not inside the loop
    max_index = -1
    max_len = 0
    for batch in itertools.count():
        try:
            # target_len is the last element of the iterator tuple
            batch_len = np.amax(train_sess.run(next_batch[-1]))
            if batch_len > max_len:
                max_len = batch_len
                max_index = batch

        except tf.errors.OutOfRangeError:
            num_batches = batch
            break

    # Second pass: skip every batch except the largest one and train on it.
    # Re-initializing reshuffles the dataset, which invalidates max_index --
    # this is exactly the problem described above.
    train_sess.run(train_model.iterator.initializer)
    for batch in range(num_batches):
        try:
            if batch == max_index:
                _, _ = loaded_train_model.train(train_sess)
            else:
                train_sess.run(next_batch)

        except tf.errors.OutOfRangeError:
            break

    return num_batches
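As a possible way to fail early, I'm also considering running a single training step on a synthetic batch padded to the maximum lengths before starting the real run: if that step fits in memory, every smaller batch should too. This is only a sketch -- build_train_model, the vocabulary size, and the batch size are placeholders, and it assumes the training graph can be built on top of any iterator with the same output structure:

import numpy as np
import tensorflow as tf

batch_size, vocab_size = 64, 32000      # placeholders
max_src_len, max_tgt_len = 120, 130     # from a pre-scan of the raw corpus (see above)

# One synthetic batch with the same structure as the real iterator:
# (source, tgt_in, tgt_out, key_weights, source_len, target_len)
synthetic = (
    np.random.randint(1, vocab_size, (batch_size, max_src_len), dtype=np.int32),
    np.random.randint(1, vocab_size, (batch_size, max_tgt_len), dtype=np.int32),
    np.random.randint(1, vocab_size, (batch_size, max_tgt_len), dtype=np.int32),
    np.ones((batch_size, max_tgt_len), dtype=np.float32),
    np.full((batch_size,), max_src_len, dtype=np.int32),
    np.full((batch_size,), max_tgt_len, dtype=np.int32),
)

dataset = tf.data.Dataset.from_tensors(synthetic)
iterator = dataset.make_one_shot_iterator()
model = build_train_model(iterator)     # hypothetical graph-building helper

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    model.train(sess)                   # an OOM, if any, surfaces here in seconds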

Upvotes: 1

Views: 151

Answers (1)

bremen_matt

Reputation: 7349

What you need is a "peek" operation. Most languages have iterators that let you peek and see whether there is more data (something like iterator.hasNext()), but the functionality you are asking for is essentially iterator.sizeOfNext(). To my knowledge, the TensorFlow iterators don't have such functionality.

Furthermore, such functionality is unlikely to be added, because I can imagine there are generators that can't provide it, so adding this feature would break backwards compatibility.
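To illustrate what I mean by a peek operation, here is a rough plain-Python sketch of a peekable wrapper -- this is not something the tf.data API offers, since TensorFlow's get_next() is a graph op whose value only exists inside sess.run():

class PeekableIterator:
    """Wraps an ordinary Python iterator so the next element can be
    inspected without being consumed."""
    _SENTINEL = object()

    def __init__(self, iterable):
        self._it = iter(iterable)
        self._cache = self._SENTINEL

    def has_next(self):
        if self._cache is self._SENTINEL:
            try:
                self._cache = next(self._it)
            except StopIteration:
                return False
        return True

    def peek(self):
        if not self.has_next():
            raise StopIteration
        return self._cache          # look at the next element without consuming it

    def __next__(self):
        if not self.has_next():
            raise StopIteration
        value, self._cache = self._cache, self._SENTINEL
        return value

    def __iter__(self):
        return self

Something like sizeOfNext() would then just be a question of inspecting peek()'s return value, but that only works when elements are materialised eagerly, which tf.data batches are not.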

Upvotes: 1
