Hasnain Raza
Hasnain Raza

Reputation: 681

Correct way of doing data augmentation in TensorFlow with the dataset api?

So, I've been playing around with the TensorFlow dataset API for loading images, and segmentation masks (for a semantic segmentation project), I would like to be able to generate batches of images and masks, with each image having randomly gone through any combination of pre-processing functions like brightness changes, contrast changes, cropping, saturation changes etc. So, the first image in my batch may have no pre-processing, second may have saturation changes, third may have brightness and saturation and so on.

I tried the following:

import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator
import random


def _resize_image(image, mask):
    image = tf.image.resize_bicubic(image, [480, 640], True)
    mask = tf.image.resize_bicubic(mask, [480, 640], True)
    return image, mask

def _corrupt_contrast(image, mask):
    image = tf.image.random_contrast(image, 0, 5)
    return image, mask


def _corrupt_saturation(image, mask):
    image = tf.image.random_saturation(image, 0, 5)
    return image, mask


def _corrupt_brightness(image, mask):
    image = tf.image.random_brightness(image, 5)
    return image, mask


def _random_crop(image, mask):
    seed = random.random()
    image = tf.random_crop(image, [240, 320, 3], seed=seed)
    mask = tf.random_crop(mask, [240, 320, 1], seed=seed)
    return image, mask


def _flip_image_horizontally(image, mask):
    seed = random.random()
    image = tf.image.random_flip_left_right(image, seed=seed)
    mask = tf.image.random_flip_left_right(mask, seed=seed)

    return image, mask


def _flip_image_vertically(image, mask):
    seed = random.random()
    image = tf.image.random_flip_up_down(image, seed=seed)
    mask = tf.image.random_flip_up_down(mask, seed=seed)

    return image, mask


def _normalize_data(image, mask):
    image = tf.cast(image, tf.float32)
    image = image / 255.0

    mask = tf.cast(mask, tf.float32)
    mask = mask / 255.0

    return image, mask


def _parse_data(image_paths, mask_paths):
    image_content = tf.read_file(image_paths)
    mask_content = tf.read_file(mask_paths)

    images = tf.image.decode_png(image_content, channels=3)
    masks = tf.image.decode_png(mask_content, channels=1)

    return images, masks


def data_batch(image_paths, mask_paths, params, batch_size=4, num_threads=2):
    # Convert lists of paths to tensors for tensorflow
    images_name_tensor = tf.constant(image_paths)
    mask_name_tensor = tf.constant(mask_paths)

    # Create dataset out of the 2 files:
    data = Dataset.from_tensor_slices(
        (images_name_tensor, mask_name_tensor))

    # Parse images and labels
    data = data.map(
        _parse_data, num_threads=num_threads, output_buffer_size=6 * batch_size)

    # Normalize images and masks for vals. between 0 and 1
    data = data.map(_normalize_data, num_threads=num_threads, output_buffer_size=6 * batch_size)

    if params['crop'] and not random.randint(0, 1):
        data = data.map(_random_crop, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    if params['brightness'] and not random.randint(0, 1):
        data = data.map(_corrupt_brightness, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    if params['contrast'] and not random.randint(0, 1):
        data = data.map(_corrupt_contrast, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    if params['saturation'] and not random.randint(0, 1):
        data = data.map(_corrupt_saturation, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    if params['flip_horizontally'] and not random.randint(0, 1):
        data = data.map(_flip_image_horizontally,
                    num_threads=num_threads, output_buffer_size=6 * batch_size)

    if params['flip_vertically'] and not random.randint(0, 1):
        data = data.map(_flip_image_vertically, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    # Shuffle the data queue
    data = data.shuffle(len(image_paths))

    # Create a batch of data
    data = data.batch(batch_size)

    data = data.map(_resize_image, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    # Create iterator
    iterator = Iterator.from_structure(data.output_types, data.output_shapes)

    # Next element Op
    next_element = iterator.get_next()

    # Data set init. op
    init_op = iterator.make_initializer(data)

    return next_element, init_op

But all batches returned by this have the same transformations applied to them, not different combinations, my guess is that the random.randint persists, and is not actually run for each batch, if so, how do I fix this to get the desired result? For an example of how I plan to use it (I feel that's irrelevant to the problem but people might still want to know) can be found here

Upvotes: 14

Views: 4319

Answers (1)

Hasnain Raza
Hasnain Raza

Reputation: 681

So the problem was indeed that the control flow with the if statements are with Python variables, and are only executed once when the graph is created, to do what I want to do, I had to define a placeholder that contains the boolean values of whether to apply a function or not (and feed in a new boolean tensor per iteration to change the augmentation), and control flow is handled by tf.cond. I pushed the new code to the GitHub link I posted in the question above if anyone is interested.

Upvotes: 11

Related Questions