Tanmoy Dewanjee

How can I apply the same augmentation to a batch of images?

I have a dataset of videos. Since the dataset is small, I am trying to augment the video data. I have not found any resources on augmenting videos, so my plan is to:

  1. Extract the required frames from the video
  2. Apply data augmentation to the extracted frames

Now, let's say I have extracted 20 frames from a single video. For my data to make sense, I will have to apply the same augmentation to all 20 frames. How can I achieve that? I am also open to other libraries if they make the work easier.
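
The requirement above can be sketched without any framework: draw the random parameters once per video, then apply them to all of its frames. This is a minimal NumPy sketch, assuming the frames are already stacked into a float array of shape (num_frames, height, width, channels) scaled to [0, 1]. The name augment_frames and the choice of a horizontal flip plus a brightness shift are illustrative only, not part of any library.

```python
import numpy as np


def augment_frames(frames, rng):
    """Apply one randomly drawn augmentation to every frame of a clip.

    frames: float array of shape (num_frames, height, width, channels) in [0, 1].
    The random parameters are drawn once per call, so all frames of the clip
    get exactly the same flip and brightness shift.
    """
    flip = rng.random() < 0.5          # one flip decision for the whole clip
    shift = rng.uniform(-0.2, 0.2)     # one brightness offset for the whole clip
    out = frames[:, :, ::-1, :] if flip else frames  # axis 2 is the width
    return np.clip(out + shift, 0.0, 1.0)


# Hypothetical clip of 20 random 32x32 RGB frames.
clip = np.random.default_rng(0).random((20, 32, 32, 3))
augmented = augment_frames(clip, np.random.default_rng(42))
```

Calling augment_frames again on the next video draws new parameters, so each video is augmented differently while its own frames stay consistent.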

I am guessing some changes to the ImageDataGenerator.flow_from_directory(...) arguments will do the trick. Here is its signature from the Keras documentation:

ImageDataGenerator.flow_from_directory(
    directory,
    target_size=(256, 256),
    color_mode="rgb",
    classes=None,
    class_mode="categorical",
    batch_size=32,
    shuffle=True,
    seed=None,
    save_to_dir=None,
    save_prefix="",
    save_format="png",
    follow_links=False,
    subset=None,
    interpolation="nearest",
)

Thank you in advance!

Answers (1)

Nicolas Gervais

You can use a tf.data.Dataset and apply the transformations after the batching operation, so that the augmentation function receives a whole batch of images at once. This will require some work to make your own directory iterator (something like this), but here's the essence of it:

import tensorflow as tf
import matplotlib.pyplot as plt
from skimage import data

# Stack 24 copies of the same test image into one tensor of shape (24, H, W, 3).
cats = tf.concat([data.chelsea()[None, ...] for _ in range(24)], axis=0)

test = tf.data.Dataset.from_tensor_slices(cats)


def augment(tensor):
    # Scale uint8 pixels to floats in [0, 1].
    tensor = tf.cast(tensor, dtype=tf.float32) / 255.0
    # On a batched (4-D) tensor, each of these ops draws a single random delta
    # per call, so every image in the batch receives the same shift.
    tensor = tf.image.random_hue(tensor, max_delta=0.5)
    tensor = tf.image.random_brightness(tensor, max_delta=0.2)
    return tensor


# Batch first, then map: augment() is called once per batch of 8 images.
test = test.batch(8).map(augment)


fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images = next(iter(test))  # the first batch of 8 augmented images
for index, image in enumerate(images):
    ax = plt.subplot(4, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    # The brightness shift can push values outside [0, 1]; clip before display.
    ax.imshow(tf.clip_by_value(image, clip_value_min=0, clip_value_max=1).numpy())
plt.show()
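
Applied to videos, the same idea can be sketched by making one dataset element equal to one whole video, so batch() is not needed to group the frames: the random draw is shared by the frames of one video and drawn again for the next. This assumes each video has already been decoded into a (frames, height, width, 3) array; the synthetic videos and the name augment_clip are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Hypothetical synthetic data: 3 videos of 20 frames each, 64x64 RGB, uint8.
# All frames of a video are identical here, which makes consistency easy to check.
rng = np.random.default_rng(0)
first_frames = rng.integers(0, 256, size=(3, 1, 64, 64, 3)).astype(np.uint8)
videos = np.repeat(first_frames, 20, axis=1)  # shape (3, 20, 64, 64, 3)


def augment_clip(clip):
    """Augment one clip of shape (frames, height, width, 3).

    random_hue and random_brightness draw a single random delta per call,
    so every frame of the clip receives the same shift.
    """
    clip = tf.cast(clip, tf.float32) / 255.0
    clip = tf.image.random_hue(clip, max_delta=0.5)
    clip = tf.image.random_brightness(clip, max_delta=0.2)
    return tf.clip_by_value(clip, 0.0, 1.0)


# Each dataset element is one whole clip, so map() augments clip by clip.
dataset = tf.data.Dataset.from_tensor_slices(videos).map(augment_clip)
augmented = [clip.numpy() for clip in dataset]
```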

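Note that this doesn't work for tf.image.random_flip_left_right: given a 4-D batch, it draws a separate flip decision for each image, so the images in a batch are not flipped together.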
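
One possible workaround, sketched here as an assumption rather than a tested recipe for your pipeline, is to draw the coin flip yourself once per batch (or per clip) and then call the deterministic tf.image.flip_left_right, which flips every image of a 4-D tensor. The name random_flip_clip and the toy clip are hypothetical.

```python
import tensorflow as tf


def random_flip_clip(clip):
    """Randomly flip a whole clip of shape (frames, height, width, channels).

    One coin flip is drawn per call; tf.image.flip_left_right applied to a
    4-D tensor flips every frame, so all frames share the same decision.
    """
    do_flip = tf.random.uniform([]) < 0.5
    return tf.cond(do_flip,
                   lambda: tf.image.flip_left_right(clip),
                   lambda: clip)


# Hypothetical 2-frame clip of 3x4 single-channel images with distinct pixels.
clip = tf.reshape(tf.range(2 * 3 * 4, dtype=tf.float32), (2, 3, 4, 1))
results = [random_flip_clip(clip) for _ in range(10)]
```

In a tf.data pipeline it could be chained after the batching step, e.g. dataset.map(random_flip_clip), next to the hue and brightness function.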
