I am trying to train a CNN on my own dataset. I've been using TFRecord files and the tf.data.TFRecordDataset API to handle my dataset. It works fine for my training dataset, but when I try to batch my validation dataset, an 'OutOfRangeError: End of sequence' error is raised. After browsing the Internet, I thought the problem was caused by the batch size of the validation set, which I had originally set to 32. But after I changed it to 2, the code ran for about 9 epochs and then the error was raised again.
I use an input function to handle the dataset. The code is below:
```python
def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels
```
For the training set, "batch_size" is set to 128 and "num_epochs" is set to None, which means the dataset repeats indefinitely. For the validation set, "batch_size" is set to 32 (later changed to 2, which still didn't work) and "num_epochs" is set to 1, since I only want to go through the validation set once. I am sure that the validation set contains enough data for the epochs, because the code below runs without raising any errors:
```python
with tf.Session() as sess:
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    for i in range(450):
        sess.run([features, labels])
        print(labels.shape)
```
In the code above, when I change 450 to 500 or anything larger, it raises the 'OutOfRangeError'. This confirms that my validation dataset contains enough data for 450 iterations with a batch size of 32.
I've tried a smaller batch size (i.e., 2) for the validation set, but I still get the same error. I can get the code running with "num_epochs" set to None in the input_fn for the validation set, but that does not seem to be how validation should work. Any help, please?
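With "num_epochs" set to 1, a single pass over N validation examples supports exactly ceil(N / batch_size) calls to get_next(). A smaller batch size only increases the number of calls that can be made before the dataset runs out. It cannot prevent the error. Below is a minimal sketch of that arithmetic. `batches_per_pass` and `N_VALID` are hypothetical names, and N_VALID = 14400 is only an assumed size that matches the "450 works, 500 fails" observation above:

```python
import math


def batches_per_pass(num_examples: int, batch_size: int) -> int:
    """Batches produced by one pass of batch(); the last one may be partial."""
    return math.ceil(num_examples / batch_size)


# Hypothetical validation-set size, chosen to be consistent with
# "450 steps work, 500 steps fail" at batch size 32.
N_VALID = 14400

for bs in (32, 2):
    # Smaller batches mean more get_next() calls per pass, but the pass still ends.
    print(bs, batches_per_pass(N_VALID, bs))
```

Any N from 14369 to 15968 fits that observation. Whatever the exact value, the iterator still runs out once the total number of validation steps exceeds the number of batches in the pass.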
This behaviour is expected. From the TensorFlow documentation:
"If the iterator reaches the end of the dataset, executing the Iterator.get_next() operation will raise a tf.errors.OutOfRangeError. After this point the iterator will be in an unusable state, and you must initialize it again if you want to use it further."
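A plain Python iterator behaves in a closely analogous way, which may make the quoted rule easier to picture. This is a pure-Python analogy, not TensorFlow code. `validation_batches` is a hypothetical stand-in for the batched validation set, and StopIteration plays the role of tf.errors.OutOfRangeError:

```python
# Pure-Python analogy (not TensorFlow): once an iterator reaches the end of its
# data it stays exhausted; the only way to read again is a fresh iterator.
validation_batches = [[1, 2], [3, 4], [5]]  # hypothetical batches

it = iter(validation_batches)
consumed = list(it)   # reads every batch: [[1, 2], [3, 4], [5]]
leftover = list(it)   # [] -- exhausted; next(it) would raise StopIteration

it = iter(validation_batches)  # "re-initializing" = creating a new iterator
first_again = next(it)         # [1, 2]
print(consumed, leftover, first_again)
```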
The error is not raised when you set dataset.repeat(None) because the dataset is repeated indefinitely and is therefore never exhausted.
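The effect of batch() followed by repeat(count) can be modelled with itertools. Below is a minimal pure-Python sketch, assuming the same order as in input_fn (batch first, then repeat). `batch` and `repeat` here are hypothetical stand-ins, not the tf.data methods. A finite count ends after count passes, while None never ends. Because batching happens before repeating, each pass ends with its own partial batch.

```python
import itertools
from typing import Iterator, List, Optional


def batch(examples: List[int], batch_size: int) -> List[List[int]]:
    """Split one pass over the data into batches; the last may be smaller."""
    return [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]


def repeat(batches: List[List[int]], count: Optional[int]) -> Iterator[List[int]]:
    """count=None repeats forever (like repeat(None)); otherwise `count` passes."""
    passes = itertools.repeat(batches) if count is None else itertools.repeat(batches, count)
    return itertools.chain.from_iterable(passes)


data = list(range(5))
print(list(repeat(batch(data, 2), 1)))                       # one pass, then it stops
print(list(itertools.islice(repeat(batch(data, 2), None), 5)))  # never stops; take 5
```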
To solve your issue, you should change your code to this:
```python
import tensorflow as tf  # TF 1.x API (tf.Session)

n_steps = 450
...
with tf.Session() as sess:
    # Training
    features, labels = input_fn(True, training_list, 32, 1, 1)
    for step in range(n_steps):
        sess.run([features, labels])
        ...
    ...
    # Validation: keep pulling batches until the single-pass dataset is exhausted
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    try:
        while True:
            sess.run([features, labels])
            ...
    except tf.errors.OutOfRangeError:
        print("End of dataset")  # ==> "End of dataset"
```
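The "loop until OutOfRangeError" pattern is also the natural place to accumulate a validation metric over the whole set. Here is a pure-Python sketch of that idea. The stand-in `OutOfRangeError`, `make_get_next` and `evaluate` are hypothetical helpers that mimic the session loop without TensorFlow, and each batch is assumed to be a (predictions, labels) pair:

```python
from typing import Callable, List, Tuple

Batch = Tuple[List[int], List[int]]  # (predictions, labels)


class OutOfRangeError(Exception):
    """Stand-in for tf.errors.OutOfRangeError (pure-Python sketch)."""


def make_get_next(batches: List[Batch]) -> Callable[[], Batch]:
    """Hypothetical helper: one batch per call, then OutOfRangeError."""
    it = iter(batches)

    def get_next() -> Batch:
        try:
            return next(it)
        except StopIteration:
            raise OutOfRangeError("End of sequence") from None

    return get_next


def evaluate(get_next: Callable[[], Batch]) -> float:
    """Drain the iterator once, counting correct predictions per example."""
    correct, total = 0, 0
    try:
        while True:
            preds, labels = get_next()
            correct += sum(int(p == y) for p, y in zip(preds, labels))
            total += len(labels)
    except OutOfRangeError:
        pass  # the validation pass is complete
    return correct / total if total else 0.0


batches = [([1, 0], [1, 1]), ([0, 0], [0, 0]), ([1], [0])]
print(evaluate(make_get_next(batches)))  # 3 correct out of 5 examples -> 0.6
```

The sketch counts examples rather than averaging per-batch accuracies, so the smaller last batch is not over-weighted. Averaging the three batch accuracies here would give 0.5 instead of 0.6.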
You can also change your input_fn to return an initializable iterator, which you can re-initialize to run the evaluation at every epoch:
```python
def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_initializable_iterator()
    return iterator
```
```python
n_epochs = 10
freq_eval = 1

training_iterator = input_fn(True, training_list, 32, 1, 1)
training_features, training_labels = training_iterator.get_next()

val_iterator = input_fn(False, valid_list, 32, 1, 1)
val_features, val_labels = val_iterator.get_next()

with tf.Session() as sess:
    for epoch in range(n_epochs):
        # Training: re-initialize so each epoch starts from the beginning of the data
        sess.run(training_iterator.initializer)
        try:
            while True:
                sess.run([training_features, training_labels])
        except tf.errors.OutOfRangeError:
            pass  # end of this training epoch

        # Validation: a full pass over the validation set every freq_eval epochs
        if (epoch + 1) % freq_eval == 0:
            sess.run(val_iterator.initializer)
            try:
                while True:
                    sess.run([val_features, val_labels])
            except tf.errors.OutOfRangeError:
                pass  # end of the validation pass
```
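The toy model below shows why the initializer has to run at the start of each epoch. It is not TensorFlow code: `ToyIterator`, `run_pass`, `batches_seen_per_epoch` and the stand-in `OutOfRangeError` are hypothetical names. It counts how many batches each epoch actually reads when the iterator is or is not re-initialized:

```python
from typing import Iterable, List


class OutOfRangeError(Exception):
    """Stand-in for tf.errors.OutOfRangeError (pure-Python sketch)."""


class ToyIterator:
    """Toy model of a tf.data iterator over a fixed sequence of batches."""

    def __init__(self, batches: Iterable[int]):
        self._batches = list(batches)
        self._it = iter(self._batches)  # ready to read, like a one-shot iterator

    def initialize(self) -> None:
        self._it = iter(self._batches)  # what running `iterator.initializer` models

    def get_next(self) -> int:
        try:
            return next(self._it)
        except StopIteration:
            raise OutOfRangeError("End of sequence") from None


def run_pass(iterator: ToyIterator) -> int:
    """Read batches until OutOfRangeError; return how many were read."""
    n = 0
    try:
        while True:
            iterator.get_next()
            n += 1
    except OutOfRangeError:
        pass
    return n


def batches_seen_per_epoch(n_epochs: int, n_batches: int, reinitialize: bool) -> List[int]:
    it = ToyIterator(range(n_batches))
    counts = []
    for _ in range(n_epochs):
        if reinitialize:
            it.initialize()
        counts.append(run_pass(it))
    return counts


print(batches_seen_per_epoch(3, 4, reinitialize=True))   # [4, 4, 4]
print(batches_seen_per_epoch(3, 4, reinitialize=False))  # [4, 0, 0]
```

Without re-initialization, every epoch after the first finds the iterator already exhausted. In the toy model that shows up as epochs with zero batches. In graph code that does not catch the exception, it shows up as the OutOfRangeError from the question.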
I advise you to have a close look at the official TensorFlow guide if you want a better understanding of what is happening under the hood.