Reputation: 1594
Question
In TensorFlow, I frequently run into OOM errors during the first epoch of training. However, because the network is large, the first epoch takes around an hour, which is far too long for testing new hyperparameters quickly.
Ideally, I'd like to be able to sort the iterator so that I can just run get_next() once on the largest batch.
How can I do this? Or perhaps there is a better way to implement fail early?
The iterator is in the format: (source, tgt_in, tgt_out, key_weights, source_len, target_len)
where I'm looking to sort by target length. It is padded and batched before being returned.
The dataset is a list of sentences, bucketed with similar lengths. I would like to find the largest batch in the iterator and run only it.
Some Code
The code below would work if the initializer didn't shuffle the iterator every time, which destroys the information gained about the position of the largest batch. I'm not quite sure how to modify it: as soon as one reads the length of the batch using get_next(), the batch has already been "popped" and can't be used as input to the model anymore.
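import itertools

import numpy as np
import tensorflow as tf


def verify_hparams():
    # First pass: find the index of the batch with the longest target sequence.
    train_sess.run(train_model.iterator.initializer)
    max_index = -1
    max_len = 0
    for batch in itertools.count():
        try:
            # target_len is the last element of the iterator tuple.
            batch_len = np.amax(train_sess.run(train_model.iterator.get_next()[-1]))
            if batch_len > max_len:
                max_len = batch_len
                max_index = batch
        except tf.errors.OutOfRangeError:
            # `batch` batches were read successfully before the iterator ran out.
            num_batches = batch
            break
    # Second pass: skip ahead to the largest batch and train on it.
    # Re-running the initializer reshuffles the data, so max_index is stale here.
    train_sess.run(train_model.iterator.initializer)
    for batch in range(num_batches):
        try:
            if batch == max_index:  # compare by value, not identity (`is`)
                _, _ = loaded_train_model.train(train_sess)
            else:
                train_sess.run(train_model.iterator.get_next())
        except tf.errors.OutOfRangeError:
            break
    return num_batches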
Upvotes: 1
Views: 151
Reputation: 7349
What you need is a "peek" operation. Most languages have iterators which allow you to peek and see if there is more data (something like iterator.hasNext()). But the functionality you are asking for is essentially something like iterator.sizeOfNext(). To my knowledge, TensorFlow iterators don't have such functionality.
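To make the idea of "peek" concrete, here is a minimal sketch for an ordinary Python iterator. It is not a TensorFlow API: the class name PeekableIterator and the toy batches are hypothetical, and the wrapper simply buffers one item so it can be inspected before it is consumed.

```python
class PeekableIterator:
    """Wraps any Python iterator and buffers one item so it can be inspected.

    Hypothetical helper for illustration; not part of TensorFlow.
    """

    _EMPTY = object()  # sentinel meaning "nothing is buffered"

    def __init__(self, iterable):
        self._it = iter(iterable)
        self._buffer = self._EMPTY

    def __iter__(self):
        return self

    def peek(self):
        # Pull the next item into the buffer without handing it out.
        if self._buffer is self._EMPTY:
            self._buffer = next(self._it)  # raises StopIteration if exhausted
        return self._buffer

    def __next__(self):
        # Hand out the buffered item first, if there is one.
        if self._buffer is not self._EMPTY:
            item, self._buffer = self._buffer, self._EMPTY
            return item
        return next(self._it)


# Toy "batches" of target lengths.
batches = PeekableIterator([[3, 5], [9, 2], [4]])
first_peek = batches.peek()   # inspect without consuming
first_next = next(batches)    # the same batch is still delivered
remaining = list(batches)
```

Peeking only looks one element ahead, so on its own it would not locate the largest batch in a shuffled epoch.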
Furthermore, such functionality is unlikely to be added, because I can imagine there are generators which can't provide it, and so adding this feature would break backwards compatibility.
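One possible workaround, sketched below under stated assumptions, is to avoid relying on batch positions at all: scan the epoch once, keep the whole largest batch (not just its length) in memory, and use that batch for the memory test. The sketch uses plain Python stand-ins; Batch and largest_batch are hypothetical names. In TensorFlow terms this would correspond to fetching the full tuple from get_next() in one run call and later supplying those values to the model (for example through feed_dict), which assumes the model can accept fed values.

```python
import random
from collections import namedtuple

# Hypothetical stand-in for one element of the iterator tuple.
Batch = namedtuple("Batch", ["source", "target_len"])


def largest_batch(batches):
    """Return (batch, length) for the batch with the longest target sequence.

    Works in a single pass, so it does not matter that the order is shuffled.
    """
    best, best_len = None, -1
    for batch in batches:
        batch_len = max(batch.target_len)
        if batch_len > best_len:
            best, best_len = batch, batch_len
    return best, best_len


# Toy data, shuffled to mimic a reshuffling initializer.
batches = [Batch("a", [4, 7]), Batch("b", [12, 3]), Batch("c", [9])]
random.shuffle(batches)
worst, worst_len = largest_batch(iter(batches))
```

Because the batch itself is retained, nothing needs to be re-read from the iterator after the scan.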
Upvotes: 1