ivan

Reputation: 111

TensorFlow CNN image augmentation pipeline

I'm trying to learn the new TensorFlow APIs and I'm a bit lost on where to get a handle on my input batch tensors so that I can manipulate and augment them with, for example, tf.image.

This is my current network and pipeline:

trainX, testX, trainY, testY = read_data()
# trainX [num_image, height, width, channels], these are numpy arrays

#...
train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))
test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))

#...
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 
                 train_dataset.output_shapes)
features, labels = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)

#...defining cnn architecture...

# In the train loop (num_epochs is an illustrative name)
for epoch in range(num_epochs):
    sess.run(train_init_op)  # switch to train data
    sess.run(train_step, ...)  # run a train step

    #...
    sess.run(test_init_op)  # switch to test data
    test_loss = sess.run(loss, ...)  # test loss after the epoch

I'm using the Dataset API to create two datasets so that I can compute and log both the train and test loss inside the training loop.

Where in this pipeline would I manipulate and distort my input batch of images? I'm not creating any tf.placeholders for my trainX input batches, so I don't see which tensor to pass to tf.image; tf.image.flip_up_down, for example, requires a 3-D or 4-D tensor.

Upvotes: 0

Views: 4097

Answers (1)

DomJack

Reputation: 4183

There's a really good article and talk released recently that go over the API in a lot more detail than my response here. Here's a brief example:

import tensorflow as tf
import numpy as np


def read_data():
    n_train = 100
    n_test = 50
    height = 20
    width = 30
    channels = 3
    trainX = (np.random.random(
        size=(n_train, height, width, channels)) * 255).astype(np.uint8)
    testX = (np.random.random(
            size=(n_test, height, width, channels))*255).astype(np.uint8)
    trainY = (np.random.random(size=(n_train,))*10).astype(np.int32)
    testY = (np.random.random(size=(n_test,))*10).astype(np.int32)
    return trainX, testX, trainY, testY


trainX, testX, trainY, testY = read_data()
# trainX [num_image, height, width, channels], these are numpy arrays


train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))
test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))


def map_single(x, y):
    print('Map single:')
    print('x shape: %s' % str(x.shape))
    print('y shape: %s' % str(y.shape))
    x = tf.image.per_image_standardization(x)
    # Consider: x = tf.image.random_flip_left_right(x)
    return x, y


def map_batch(x, y):
    print('Map batch:')
    print('x shape: %s' % str(x.shape))
    print('y shape: %s' % str(y.shape))
    # Note: this flips ALL images left to right. Not sure this is what you want
    # UPDATE: looks like tf documentation is wrong and you need a 3D tensor?
    # return tf.image.flip_left_right(x), y
    return x, y


batch_size = 32
train_dataset = train_dataset.repeat().shuffle(100)
train_dataset = train_dataset.map(map_single, num_parallel_calls=8)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.map(map_batch)
train_dataset = train_dataset.prefetch(2)

test_dataset = test_dataset.map(
        map_single, num_parallel_calls=8).batch(batch_size).map(map_batch)
test_dataset = test_dataset.prefetch(2)


iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 
                 train_dataset.output_shapes)
features, labels = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)


with tf.Session() as sess:
    sess.run(train_init_op)
    feat, lab = sess.run((features, labels))

    print(feat.shape)
    print(lab.shape)

    sess.run(test_init_op)
    feat, lab = sess.run((features, labels))

    print(feat.shape)
    print(lab.shape)    
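
If you switch map_single over to random augmentation, you probably don't want the same function applied to the test set. Here is a minimal sketch of one way to split the two. Assumptions on my part: tensorflow.compat.v1 is importable (TF >= 1.14 or 2.x), augment_train and preprocess_test are names I made up, and the random arrays stand in for read_data().

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x

tf.disable_v2_behavior()  # keep the graph/session style used above


def augment_train(x, y):
    # Runs before batch(), so x is one 3-D image [height, width, channels].
    x = tf.cast(x, tf.float32)
    x = tf.image.random_flip_left_right(x)            # flipped with p=0.5
    x = tf.image.random_brightness(x, max_delta=32.0)
    x = tf.image.per_image_standardization(x)
    return x, y


def preprocess_test(x, y):
    # No randomness at test time, only the deterministic normalisation.
    x = tf.cast(x, tf.float32)
    return tf.image.per_image_standardization(x), y


# Stand-in data with the same layout as read_data() above.
train_x = np.random.randint(0, 256, size=(100, 20, 30, 3)).astype(np.uint8)
train_y = np.random.randint(0, 10, size=(100,)).astype(np.int32)
test_x = np.random.randint(0, 256, size=(50, 20, 30, 3)).astype(np.uint8)
test_y = np.random.randint(0, 10, size=(50,)).astype(np.int32)

train_ds = (tf.data.Dataset.from_tensor_slices((train_x, train_y))
            .repeat().shuffle(100)
            .map(augment_train, num_parallel_calls=4)
            .batch(32).prefetch(2))
test_ds = (tf.data.Dataset.from_tensor_slices((test_x, test_y))
           .map(preprocess_test, num_parallel_calls=4)
           .batch(32).prefetch(2))

# Both datasets yield (float32 images, int32 labels), so one
# reinitializable iterator can serve either of them.
iterator = tf.data.Iterator.from_structure(
    tf.data.get_output_types(train_ds),
    tf.data.get_output_shapes(train_ds))
features, labels = iterator.get_next()
train_init_op = iterator.make_initializer(train_ds)
test_init_op = iterator.make_initializer(test_ds)

with tf.Session() as sess:
    sess.run(train_init_op)
    train_feat, train_lab = sess.run((features, labels))
    sess.run(test_init_op)
    test_feat, test_lab = sess.run((features, labels))

print(train_feat.shape, test_feat.shape)
```

Both maps cast to float32 first so that the two datasets have identical output types, which the shared iterator requires.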

A few notes:

  1. This approach relies on being able to load your entire dataset into memory. If you cannot, consider using tf.data.Dataset.from_generator, though that can make shuffling slow if your shuffle buffer is large. My preferred method is to load some keys tensor entirely into memory (it might just be the indices of each example) and then map each key to its data values using tf.py_func. This is slightly less efficient than converting to TFRecords, but with prefetching it likely won't affect performance. Since the shuffling is done before the mapping, you only have to hold shuffle_buffer keys in memory, rather than shuffle_buffer examples.
  2. To augment your dataset, use tf.data.Dataset.map either before or after the batch operation, depending on whether you want an element-wise operation (3D image tensor, map before batch) or a batch-wise operation (4D image tensor, map after batch). Note that the documentation for tf.image.flip_left_right looks out of date, since I get an error when I try to use a 4D tensor. If you want to augment your data randomly, use tf.image.random_flip_left_right rather than tf.image.flip_left_right.
  3. If you're using a tf.estimator.Estimator (or wouldn't mind converting your code to using it), then check out tf.estimator.train_and_evaluate for an in-built way of switching between datasets.
  4. Consider shuffling/repeating your dataset with the shuffle/repeat methods. See the article for notes on efficiencies. In particular, repeat -> shuffle -> map -> batch -> batch-wise map -> prefetch seems to be the best ordering of operations for most applications.
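
The key-based approach from note 1 could look roughly like this. It's a sketch under assumptions, not a drop-in solution: load_example is a hypothetical loader that fakes reading one image from disk, the keys are plain integer indices, and tensorflow.compat.v1 is assumed importable (TF >= 1.14 or 2.x).

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x

tf.disable_v2_behavior()

HEIGHT, WIDTH, CHANNELS = 20, 30, 3
N_EXAMPLES = 100


def load_example(index):
    # Hypothetical loader: replace with code that reads example `index`
    # from disk. Here it just fabricates an image and a label.
    index = int(index)
    rng = np.random.RandomState(index)
    image = rng.randint(
        0, 256, size=(HEIGHT, WIDTH, CHANNELS)).astype(np.uint8)
    label = np.int32(index % 10)
    return image, label


def map_key(index):
    # tf.py_func calls the Python loader when the element is requested.
    image, label = tf.py_func(load_example, [index], (tf.uint8, tf.int32))
    # py_func loses static shape information, so restore it by hand.
    image.set_shape((HEIGHT, WIDTH, CHANNELS))
    label.set_shape(())
    return image, label


# Only the keys are held in memory and shuffled; images load lazily.
keys = np.arange(N_EXAMPLES, dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices(keys)
dataset = dataset.repeat().shuffle(N_EXAMPLES)
dataset = dataset.map(map_key, num_parallel_calls=4)
dataset = dataset.batch(32).prefetch(2)

iterator = tf.data.make_initializable_iterator(dataset)
features, labels = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    feat, lab = sess.run((features, labels))

print(feat.shape, lab.shape)
```

Because shuffle comes before map, the shuffle buffer contains int64 keys rather than full images, which is what keeps the memory cost low.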

Upvotes: 5
