Reputation: 502
How should I implement an augmentation pipeline in which my dataset gets extended instead of having its images replaced by the augmented ones? In other words, how do I use map() calls to augment the samples while preserving the originals?
import tensorflow as tf

records_path = DATA_DIR + '/' + 'TFRecords' + TRAIN + 'train_0.tfrecord'

# Create a dataset from the TFRecord file
dataset = tf.data.TFRecordDataset(filenames=records_path)
# Parse the records, cache them, then replace each image with its central crop
dataset = dataset.map(parsing_fn).cache().map(
    lambda image, label: (tf.image.central_crop(image, 0.5), label))
dataset = dataset.shuffle(100)
dataset = dataset.batch(2)
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
From the code above I was expecting that, when iterating through the batches, I would get both the original image and its cropped version. Beyond that, I guess I haven't properly understood how the cache() method behaves.
I then used the code below to display the images, but it only plots cropped images.
import matplotlib.pyplot as plt

iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

for i in range(10):
    image, label = iterator.get_next()
    img_array = image[0].numpy()
    plt.imshow(img_array)
    plt.show()
    print('label: ', label[0])
    img_array = image[1].numpy()
    plt.imshow(img_array)
    plt.show()
    print('label: ', label[1])
Upvotes: 0
Views: 291
Reputation: 741
In your case, cache() keeps the dataset in memory after parsing_fn has been applied. It only improves performance; it does not extend the dataset. Once you have iterated over the whole dataset, every image is kept in memory, so the next iteration will be faster because parsing_fn does not have to be applied again.
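To see this behavior in isolation, here is a minimal, self-contained sketch; the toy parse function below is hypothetical and stands in for parsing_fn:

import tensorflow as tf

def parse(x):
    # Stands in for an expensive parsing_fn
    tf.print('parsing element', x)
    return x * 2

ds = tf.data.Dataset.range(3).map(parse).cache()

# First pass: parse runs and its results are stored in the cache
for element in ds:
    pass

# Second pass: elements come straight from the cache,
# so 'parsing element' is not printed again
for element in ds:
    pass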
If you intend to get both the original image and its crop when iterating over the dataset, return both of them from your map() function:
dataset = dataset.map(parsing_fn).cache().map(lambda image, label: (tf.image.central_crop(image, 0.5), image, label))
Then, in your iteration, you can retrieve both:
for i in range(10):
    crop, image, label = iterator.get_next()
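For reference, a minimal TF2 sketch of the full pipeline with this three-element map(); it assumes records_path and parsing_fn from the question, and iterates the dataset directly in eager mode instead of through a one-shot iterator:

import tensorflow as tf

dataset = tf.data.TFRecordDataset(filenames=records_path)
dataset = (dataset
           .map(parsing_fn)
           .cache()
           .map(lambda image, label:
                (tf.image.central_crop(image, 0.5), image, label))
           .shuffle(100)
           .batch(2))

# tf.data datasets are directly iterable in TF2 eager mode
for crop, image, label in dataset.take(10):
    print(crop.shape, image.shape, label[0].numpy())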
Upvotes: 1