rodrigo-silveira

Reputation: 13088

How to perform data augmentation in TensorFlow Estimator's input_fn

Using TensorFlow's Estimator API, at what point in the input pipeline should I perform data augmentation?

According to this official TensorFlow guide, one place to perform data augmentation is the input_fn:

import tensorflow as tf  # TensorFlow 1.x API, as used in the guide

def parse_fn(example):
  "Parse TFExample records and perform simple data augmentation."
  example_fmt = {
    "image": tf.FixedLenFeature((), tf.string, ""),
    "label": tf.FixedLenFeature((), tf.int64, -1)
  }
  parsed = tf.parse_single_example(example, example_fmt)
  image = tf.image.decode_image(parsed["image"])

  # _augment_helper (not shown in the guide) augments the image
  # using slice, reshape and resize_bilinear
  image = _augment_helper(image)

  return image, parsed["label"]

def input_fn():
  files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
  dataset = files.interleave(tf.data.TFRecordDataset)
  dataset = dataset.map(map_func=parse_fn)
  # ...
  return dataset

My question

If I perform data augmentation inside input_fn, does parse_fn return a single example, or a batch containing the original input image plus all of its augmented variants? If it should return only a single [augmented] example, how do I ensure that every image in the dataset is used both in its un-augmented form and in all of its variants?

Upvotes: 3

Views: 1140

Answers (2)

Luciano Dourado

Reputation: 501

parse_fn returns a single example for every call (dataset.map calls it once per element). If you then apply the .batch() operation to the dataset, the dataset yields batches of parsed images.
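To make the per-element behaviour concrete, here is a minimal sketch assuming TensorFlow 2.x with eager execution (the question uses the 1.x API, but map, flat_map and batch group elements the same way). The toy parse_fn, the flip used as augment and original_and_augmented are hypothetical stand-ins, not part of the guide. The second half shows one possible way to cover the other part of the question: flat_map lets each record produce two examples, the original and its augmented variant.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def parse_fn(x):
    # Hypothetical stand-in for the real parse_fn: one record in, one example out.
    image = tf.reshape(tf.range(4, dtype=tf.float32), [2, 2]) + tf.cast(x, tf.float32)
    label = x % 2
    return image, label


def augment(image):
    # Hypothetical deterministic "augmentation": flip a 2-D image left/right.
    return tf.reverse(image, axis=[1])


records = tf.data.Dataset.range(6)  # six toy records

# map: exactly one example per record (6 elements); batch(4) groups them as 4 + 2.
mapped = records.map(parse_fn).batch(4)
batch_sizes = [int(images.shape[0]) for images, _ in mapped]
print(batch_sizes)  # [4, 2]


def original_and_augmented(image, label):
    # Emit the un-augmented example followed by its augmented variant.
    return tf.data.Dataset.from_tensor_slices(
        (tf.stack([image, augment(image)]), tf.stack([label, label])))


# flat_map: each record becomes 2 examples (12 elements), batched as 4, 4, 4.
both = records.map(parse_fn).flat_map(original_and_augmented).batch(4)
both_sizes = [int(images.shape[0]) for images, _ in both]
print(both_sizes)  # [4, 4, 4]
```

With flat_map every image is seen once as-is and once augmented per epoch, at the cost of doubling the number of examples.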

Upvotes: 0

NiallJG
NiallJG

Reputation: 1961

If you use an iterator on your dataset, your _augment_helper function will be applied on each iteration, to every element fed through the pipeline (since you call parse_fn in dataset.map).

Change the end of your input_fn to:

  iterator = dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  return next_element
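The iterator pattern with make_one_shot_iterator() and get_next() can be sketched in isolation as follows. This assumes TensorFlow 2.x and uses its tf.compat.v1 graph-mode API to mirror the 1.x code; the multiplication in map is a hypothetical stand-in for parse_fn and _augment_helper. Each run of the get_next() tensor pulls one more element through the map function.

```python
import tensorflow as tf  # assumes TF 2.x, using its tf.compat.v1 graph-mode API

graph = tf.Graph()
with graph.as_default():
    # Toy pipeline; the multiplication stands in for parse_fn / _augment_helper.
    dataset = tf.data.Dataset.range(3).map(lambda x: x * 10)
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    next_element = iterator.get_next()  # one symbolic tensor, reused on every run

    with tf.compat.v1.Session(graph=graph) as sess:
        # Each sess.run pulls the next element through the map function.
        values = [int(sess.run(next_element)) for _ in range(3)]
        try:
            sess.run(next_element)  # the dataset is exhausted after 3 elements
            exhausted = False
        except tf.errors.OutOfRangeError:
            exhausted = True

print(values)  # [0, 10, 20]
```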

I've tested this with a simple augmentation function:

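  def _augment_helper(image):
      # print() runs once, when dataset.map traces this function,
      # not once per element (use tf.print for per-element output)
      print(image.shape)
      # decode_image yields uint8; cast so the float clip bounds match
      image = tf.cast(image, tf.float32)
      # delta drawn uniformly from [-255.0, 255.0); seed=1
      image = tf.image.random_brightness(image, max_delta=255.0, seed=1)
      image = tf.clip_by_value(image, 0.0, 255.0)
      return image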

Change 255.0 to the maximum pixel value in your dataset; I used 255.0 because my example dataset had 8-bit pixel values.
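A random augmentation inside dataset.map is re-drawn every time the data is iterated, so each epoch sees differently augmented versions of the same images. A minimal sketch, assuming TensorFlow 2.x with eager execution and no random seed set; here max_delta is lowered to 32.0 (an assumption, not the value above) so that the clipping never triggers on the mid-grey toy images.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def _augment_helper(image):
    # Same idea as the helper above, on a float image in [0, 255];
    # max_delta=32.0 is an assumed value for this demo.
    image = tf.image.random_brightness(image, max_delta=32.0)
    return tf.clip_by_value(image, 0.0, 255.0)


# Four toy 2x2 single-channel "images", all mid-grey (128.0).
images = tf.fill([4, 2, 2, 1], 128.0)
dataset = tf.data.Dataset.from_tensor_slices(images).map(_augment_helper)

# Each pass creates a new iterator and re-runs the random op,
# so the same source image gets a new brightness delta in each epoch.
epoch_1 = np.stack([img.numpy() for img in dataset])
epoch_2 = np.stack([img.numpy() for img in dataset])

print(epoch_1.shape)  # (4, 2, 2, 1)
```

No augmented copies are stored; this per-epoch re-sampling is the usual alternative to enumerating every variant up front.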

Upvotes: 1
