haofeng

Reputation: 651

How to set a probability for tf.keras.layers.RandomFlip?

Is it possible to set a probability for the random flip operation when using tf.keras.layers.RandomFlip?

For example:

def augmentation():
    data_augmentation = keras.Sequential([
        keras.layers.RandomFlip("horizontal", p=0.5),
        keras.layers.RandomRotation(0.2, p=0.5)
    ])
    return data_augmentation

Upvotes: 2

Views: 2108

Answers (1)

AloneTogether

Reputation: 26708

Try creating a simple Lambda layer and defining your probability in a separate function:

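import random

import tensorflow as tf
from tensorflow import keras

def random_flip_on_probability(image, probability=0.5):
    # random.random() runs in Python, so a fresh draw per call is only guaranteed in eager mode.
    if random.random() < probability:
        # Deterministic flip: tf.image.random_flip_left_right would itself flip only half the time.
        return tf.image.flip_left_right(image)
    return image

def augmentation():
    data_augmentation = keras.Sequential([
        keras.layers.Lambda(random_flip_on_probability),
        keras.layers.RandomRotation(0.2)  # RandomRotation does not accept a `p` argument
    ])
    return data_augmentation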

If you need the augmentation to run as part of the model during training or inference, you will have to define your own custom layer. Try something like this:

import pathlib

import matplotlib.pyplot as plt
import tensorflow as tf

class RandomFlipOnProbability(tf.keras.layers.Layer):
  def __init__(self, probability):
    super(RandomFlipOnProbability, self).__init__()
    self.probability = probability

  def call(self, images):
    # One random draw per call: the whole batch is either flipped or left unchanged.
    return tf.cond(
        tf.random.uniform(()) < self.probability,
        lambda: tf.image.flip_left_right(images),
        lambda: images)

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)


random_layer = RandomFlipOnProbability(probability = 0.9)
normalization_layer = tf.keras.layers.Rescaling(1./255)

images, _ = next(iter(train_ds.take(1)))
images = normalization_layer(random_layer(images))
image = images[0]

plt.imshow(image.numpy())
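
The layer above makes a single random draw per call, so every image in a batch is flipped together. If each image should be flipped independently, one possible variation is sketched below. It is not part of the original answer: the class name RandomFlipPerImage and the handling of the `training` argument are assumptions made for illustration, and the sketch assumes a 4-D input of shape (batch, height, width, channels).

```python
import tensorflow as tf


class RandomFlipPerImage(tf.keras.layers.Layer):
    """Hypothetical layer: flips each image in a batch left-right with its own random draw."""

    def __init__(self, probability=0.5, **kwargs):
        super().__init__(**kwargs)
        self.probability = probability

    def call(self, images, training=True):
        # Assumption: augmentation is skipped when the layer is called with training=False.
        if not training:
            return images
        batch_size = tf.shape(images)[0]
        # One uniform sample in [0, 1) per image, shaped so it broadcasts over H, W and C.
        flip_mask = tf.random.uniform([batch_size, 1, 1, 1]) < self.probability
        flipped = tf.image.flip_left_right(images)
        # Pick the flipped image where the mask is True, the original elsewhere.
        return tf.where(flip_mask, flipped, images)


# Usage on a small dummy batch of two 2x3 single-channel images.
dummy_images = tf.reshape(tf.range(12, dtype=tf.float32), [2, 2, 3, 1])
augmented = RandomFlipPerImage(probability=0.5)(dummy_images)
```

Both versions of every image are computed before `tf.where` selects one, which costs an extra flip per batch in exchange for avoiding a per-image `tf.cond`.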


Upvotes: 2
