Luciano Dourado

Reputation: 502

Mapping dataset to augmentation function does not preserve my original samples

How can I implement an augmentation pipeline that extends my dataset instead of replacing the images with their augmented versions? In other words, how do I use map calls to augment the data while preserving the original samples?

Threads I've checked: 1, 2

Code I'm currently using:

import tensorflow as tf

# DATA_DIR, TRAIN and parsing_fn are defined elsewhere in my code
records_path = DATA_DIR + '/' + 'TFRecords' + TRAIN + 'train_0.tfrecord'
# Create a dataset
dataset = tf.data.TFRecordDataset(filenames=records_path)
# Parse, cache the parsed samples, then crop each image to its central 50%
dataset = dataset.map(parsing_fn).cache().map(lambda image, label: (tf.image.central_crop(image, 0.5), label))
dataset = dataset.shuffle(100)
dataset = dataset.batch(2)
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

I expected that, when iterating through the batches, I would get both the original image and its cropped version. Beyond that, I suspect I haven't properly understood how the cache method behaves.

I then used the code below to display the images, but it only plots cropped images, in shuffled order.

import matplotlib.pyplot as plt

iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)

for i in range(10):
    image, label = iterator.get_next()
    # Each batch holds two samples (batch(2)); show both of them
    for j in range(2):
        plt.imshow(image[j].numpy())
        plt.show()
        print('label: ', label[j].numpy())

Upvotes: 0

Views: 291

Answers (1)

Raphael Meudec

Reputation: 741

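In your case, cache() keeps the dataset in memory as it looks after parsing_fn has been applied. It only improves performance; it does not change which elements the dataset yields. Once you have iterated over the whole dataset, every parsed image is kept in memory, so the next iteration is faster because parsing_fn does not have to be applied again. The central_crop map that follows cache() still runs on every pass, and it replaces each element with its crop rather than adding the crop next to the original.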
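A minimal sketch of this behaviour on a toy dataset, assuming TensorFlow 2.x with eager execution; the integers from Dataset.range are a stand-in for the parsed images. cache() on its own leaves the elements unchanged, and a map() placed after it yields only the transformed values, never the originals.

```python
import tensorflow as tf

# Toy stand-in for the parsed dataset: three "samples" 0, 1, 2
parsed = tf.data.Dataset.range(3)

# cache() by itself does not add or remove elements
cached_only = [int(v) for v in parsed.cache().as_numpy_iterator()]
print(cached_only)  # [0, 1, 2]

# A map() after cache() transforms each cached element and yields only the result,
# just like central_crop replaces each image in the question's pipeline
pipeline = parsed.cache().map(lambda x: x * 10)
values = [int(v) for v in pipeline.as_numpy_iterator()]
print(values)  # [0, 10, 20]: still three elements, none of them the original
```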

If you want to get both the original image and its crop when iterating over the dataset, return both of them from your map() function:

dataset = dataset.map(parsing_fn).cache().map(lambda image, label: (tf.image.central_crop(image, 0.5), image, label))
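A self-contained version of this, assuming TensorFlow 2.x; the random 8x8 images and integer labels are a hypothetical stand-in for the output of parsing_fn. After batching, each element is a (crop, image, label) triple, with the crop at half the height and width of the original.

```python
import tensorflow as tf

# Hypothetical stand-in for the parsed TFRecords: four 8x8 RGB images with labels
images = tf.random.uniform((4, 8, 8, 3))
labels = tf.constant([0, 1, 2, 3])
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# Return the crop *and* the untouched image from the same map() call
dataset = dataset.cache().map(
    lambda image, label: (tf.image.central_crop(image, 0.5), image, label))
dataset = dataset.batch(2)

shapes = []
for crop, image, label in dataset:
    # crop is 4x4 (central 50%), image keeps its original 8x8 size
    shapes.append((crop.shape.as_list(), image.shape.as_list()))
    print(crop.shape, image.shape, label.numpy())
```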

Then, in your iteration, you can retrieve both:

for i in range(10):
    crop, image, label = iterator.get_next()
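The approach above keeps the crop and the original inside a single element. If you instead want the dataset itself extended, so that the original and the augmented image become two separate samples, one option (not part of this answer; a sketch assuming TensorFlow 2.x and the same hypothetical random images) is flat_map, which can emit several elements per input. The crop is resized back to the original size here so that both samples share a shape and can be batched together.

```python
import tensorflow as tf

# Hypothetical stand-in for the parsed TFRecords
images = tf.random.uniform((4, 8, 8, 3))
labels = tf.constant([0, 1, 2, 3])
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

def original_and_crop(image, label):
    # Crop to the central 50%, then resize back so shapes match the original
    crop = tf.image.central_crop(image, 0.5)
    crop = tf.image.resize(crop, tf.shape(image)[:2])
    # Emit two separate samples: the untouched image, then its crop
    return tf.data.Dataset.from_tensors((image, label)).concatenate(
        tf.data.Dataset.from_tensors((crop, label)))

extended = dataset.cache().flat_map(original_and_crop)
extended = extended.batch(2)  # in a real pipeline, shuffle() would go before this

batches = [(x.numpy().shape, y.numpy().tolist()) for x, y in extended]
print(batches)  # four batches, each pairing an original with its crop
```

Because flat_map runs after cache(), only the parsed images are cached and the augmentation is recomputed on every pass.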

Upvotes: 1
