ajana

Reputation: 31

How to perform 10-crop image augmentation at training time using the TensorFlow 2.0 Dataset API

I am using the TensorFlow Dataset API and reading data from TFRecord files. I can use the map function with methods like random_flip_left_right and random_crop for data augmentation.

However, I am running into an issue while trying to replicate the AlexNet paper. I need to flip each image and then take 5 crops (left, top, bottom, right and middle).
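To make the per-image step concrete, here is a minimal sketch of a ten-crop function. It assumes the five crops are the four corner patches plus the center patch, which is how the AlexNet paper describes its ten-patch scheme (the paper applies it at test time). Each of those five crops is then taken again from the horizontally mirrored image. The name `ten_crop` and the 256x256 input are illustrative assumptions.

```python
import tensorflow as tf

def ten_crop(image, crop_size=227):
    """Hypothetical helper: returns a (10, crop_size, crop_size, channels) tensor
    with the 4 corner crops and the center crop of `image`, followed by the same
    5 crops of its horizontal mirror. Assumes height and width >= crop_size."""
    shape = tf.shape(image)
    h, w = shape[0], shape[1]
    # (offset_height, offset_width): top-left, top-right, bottom-left,
    # bottom-right, center
    offsets = [
        (0, 0),
        (0, w - crop_size),
        (h - crop_size, 0),
        (h - crop_size, w - crop_size),
        ((h - crop_size) // 2, (w - crop_size) // 2),
    ]
    crops = []
    for img in (image, tf.image.flip_left_right(image)):
        for dy, dx in offsets:
            crops.append(tf.image.crop_to_bounding_box(img, dy, dx, crop_size, crop_size))
    return tf.stack(crops, axis=0)

image = tf.random.uniform([256, 256, 3])
crops = ten_crop(image)
print(crops.shape.as_list())  # [10, 227, 227, 3]
```

Because the offsets are fixed, this also works for 256x256 inputs, where each half of the image is narrower than 227 px.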

So the input dataset size will increase by a factor of 10. Is there any way to do this using the TensorFlow Dataset API? The map() function returns only one image per input element, so I cannot increase the number of images.
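`map()` is one-to-one, but `tf.data` offers two standard ways to let one input element produce several output elements. The first is `flat_map`, which maps each element to a small Dataset and concatenates the results. The second is `map` followed by `unbatch`, which returns a stacked tensor and then splits its leading dimension. This toy sketch uses integers in place of images. The dataset contents are made up for illustration.

```python
import tensorflow as tf

# Toy dataset standing in for parsed images: the elements 0, 1, 2.
ds = tf.data.Dataset.range(3)

# Option 1: flat_map turns each element into a 2-element Dataset and
# concatenates them, so every input element yields two outputs.
expanded = ds.flat_map(
    lambda x: tf.data.Dataset.from_tensor_slices(tf.stack([x, x * 10])))

# Option 2: map each element to a stacked tensor of shape (2,), then unbatch
# so the leading dimension is split into separate elements.
unbatched = ds.map(lambda x: tf.stack([x, x * 10])).unbatch()

via_flat_map = [int(v) for v in expanded]
via_unbatch = [int(v) for v in unbatched]
print(via_flat_map)  # [0, 0, 1, 10, 2, 20]
print(via_unbatch)   # [0, 0, 1, 10, 2, 20]
```

For images, each `x` would be an (image, label) pair, and the stacked tensor would hold the crops.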

Here is the code I have now:

dataset = dataset.map(parse_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .map(lambda image, label: (tf.image.random_flip_left_right(image), label), num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .map(lambda image, label: (tf.image.random_crop(image, size=[227, 227, 3]), label), num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .shuffle(buffer_size=1000) \
    .repeat() \
    .batch(256) \
    .prefetch(tf.data.experimental.AUTOTUNE)

Upvotes: 3

Views: 1518

Answers (1)

Kutay YILDIZ

Reputation: 111

import tensorflow as tf

def tile_crop(img, label):
    img_shape = tf.shape(img)
    # Random 227x227 crop from the left half / the top half of the image
    crop_left = lambda img: tf.image.random_crop(img[:, :img_shape[1] // 2, :], size=[227, 227, 3])
    crop_top = lambda img: tf.image.random_crop(img[:img_shape[0] // 2, :, :], size=[227, 227, 3])
    ...  # remaining crops (bottom, right, middle)
    img = tf.image.random_flip_left_right(img)
    # Stack the crops into one tensor: (num_crops, 227, 227, 3)
    img = tf.stack([crop_left(img), crop_top(img), ...], axis=0)
    label = tf.reshape(label, [1, 1])  # size: () -> (1, 1)
    label = tf.tile(label, [5, 1])     # size: (1, 1) -> (5, 1), one label per crop
    return img, label

dt = parsed_dataset.map(tile_crop)  # size: ((5, 227, 227, 3), (5, 1))
dt = dt.unbatch()                   # size: ((227, 227, 3), (1,))

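You can then apply shuffle/repeat/batch/prefetch as you like. Make sure that every cropped image has the same size; otherwise tf.stack will fail. Note that this produces 5 crops per image. To get 10, also crop the flipped image (tf.image.flip_left_right) and tile the label 10 times.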
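The `...` above leaves the remaining crops open. Below is one hypothetical way to fill them in and check the result. It assumes "left/top/bottom/right" means a random crop taken from that half of the image and "middle" means a center crop. The name `five_crop` and the toy 512x512 dataset are made up for illustration.

```python
import tensorflow as tf

CROP = [227, 227, 3]

def five_crop(img, label):
    # Hypothetical completion of tile_crop: one random crop from each half of
    # the image plus a center crop. Each half must be >= 227 px in the sliced
    # dimension, i.e. the image must be at least 454x454.
    h, w = tf.shape(img)[0], tf.shape(img)[1]
    img = tf.image.random_flip_left_right(img)
    crops = tf.stack([
        tf.image.random_crop(img[:, : w // 2, :], size=CROP),  # left
        tf.image.random_crop(img[: h // 2, :, :], size=CROP),  # top
        tf.image.random_crop(img[h // 2 :, :, :], size=CROP),  # bottom
        tf.image.random_crop(img[:, w // 2 :, :], size=CROP),  # right
        tf.image.resize_with_crop_or_pad(img, 227, 227),       # middle
    ], axis=0)
    labels = tf.tile(tf.reshape(label, [1]), [5])  # () -> (5,)
    return crops, labels

# Toy stand-in for the parsed TFRecord dataset: 4 random 512x512 images.
images = tf.random.uniform([4, 512, 512, 3])
labels = tf.constant([0, 1, 2, 3])
parsed_dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# Each image becomes 5 (crop, label) elements after unbatch.
dt = parsed_dataset.map(five_crop).unbatch()

out_labels = [int(lbl) for _, lbl in dt]
print(out_labels)  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]
```

With 256x256 inputs (the size AlexNet rescaled its images to), each half is only 128 px, so `tf.image.random_crop` fails its size check at runtime. In that case, fixed-offset crops via `tf.image.crop_to_bounding_box` are the safer choice.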

Upvotes: 2
