Reputation: 168
I am building a preprocessing and data augmentation pipeline for my image segmentation dataset. There is a powerful Keras API for this, but I ran into the problem of reproducing the same augmentation on the image as well as on the segmentation mask (a second image). Both images must undergo exactly the same manipulations. Is this not supported yet?
Example / Pseudocode
# Geometric augmentation pipeline: flip, then rotate, then zoom.
# Every random layer is constructed with the same SEED_VAL.
data_augmentation = tf.keras.Sequential()
data_augmentation.add(
    layers.experimental.preprocessing.RandomFlip(
        mode="horizontal_and_vertical", seed=SEED_VAL))
data_augmentation.add(
    layers.experimental.preprocessing.RandomRotation(
        factor=0.4, fill_mode="constant", fill_value=0, seed=SEED_VAL))
data_augmentation.add(
    layers.experimental.preprocessing.RandomZoom(
        height_factor=(-0.0,-0.2), fill_mode='constant', fill_value=0,
        seed=SEED_VAL))

# First 80% of the 'train' split is used for training, the rest for testing.
(train_ds, test_ds), info = tfds.load(
    'somedataset',
    split=['train[:80%]', 'train[80%:]'],
    with_info=True,
)
This code does not work but illustrates how my dream api would work:
# NOTE: wished-for API only. As stated above, this call does not work; it
# illustrates passing the image and its mask through the pipeline together.
train_ds = train_ds.map(
    lambda datapoint: data_augmentation(
        (datapoint['image'], datapoint['segmentation_mask']), training=True))
The alternative is to code a custom load and manipulation / randomization method, as is proposed in the image segmentation tutorial.
Any tips on state of the art data augmentation for this type of dataset is much appreciated :)
Upvotes: 4
Views: 6872
Reputation: 17
This is the method described in the official docs, [Image Segmentation Official Tutorials][1]
class Augment(tf.keras.layers.Layer):
    """Apply the same random flip/rotation/zoom pipeline to inputs and labels.

    Two separate pipelines are built from identically configured layers that
    all receive the same ``seed``; the approach relies on equally seeded
    layers drawing the same random transformations so that image and mask
    stay aligned.

    Args:
        seed: Seed handed to every random layer of both pipelines.
    """

    def __init__(self, seed=42):
        # A Keras layer must run the base initialiser before sub-layers are
        # assigned as attributes.
        super().__init__()
        # Images may be interpolated smoothly; label masks use
        # nearest-neighbour so resampling does not blend label values.
        self.augment_inputs = self._build_pipeline(seed, "bilinear")
        self.augment_labels = self._build_pipeline(seed, "nearest")

    @staticmethod
    def _build_pipeline(seed, interpolation):
        """Return the flip -> rotate -> zoom pipeline for one of the tensors."""
        preprocessing = layers.experimental.preprocessing
        return tf.keras.Sequential([
            preprocessing.RandomFlip(mode="horizontal_and_vertical", seed=seed),
            preprocessing.RandomRotation(
                factor=0.4, fill_mode="constant", fill_value=0,
                interpolation=interpolation, seed=seed),
            # Bounds written as (lower, upper): zoom in by 0% to 20%.
            preprocessing.RandomZoom(
                height_factor=(-0.2, -0.0), fill_mode='constant', fill_value=0,
                interpolation=interpolation, seed=seed),
        ])

    def call(self, inputs, labels):
        """Return ``(augmented_inputs, augmented_labels)``."""
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels
After this, you can call the Augment() func
train_batches = (
This will make sure that your inputs and masks are augmented with the same random transformations.
Upvotes: -1
Reputation: 110
I solved this by concatenating the image and the mask along the channel axis into a single tensor and then applying the augmentation layers to it.
def augment_using_layers(images, mask, size=None):
    """Augment an image and its mask together by stacking them channel-wise.

    The mask is appended to the image as one extra channel, the combined
    tensor is passed through a single augmentation pipeline (so both receive
    the same random transformation), and the result is split again.

    Args:
        images: Image tensor indexed as (height, width, channels).
        mask: Mask tensor indexed as (height, width).
        size: Optional ``(height, width)`` of the random crop; defaults to
            the mask's own height and width.

    Returns:
        ``(image, mask)`` with the mask cast back to ``uint8``.
    """
    if size is None:
        h_s = mask.shape[0]
        w_s = mask.shape[1]
    else:
        h_s = size[0]
        w_s = size[1]

    def build_augmenter(height=h_s, width=w_s):
        """Return the flip -> zoom -> crop -> translate -> rotate pipeline."""
        flip = tf.keras.layers.RandomFlip(mode="horizontal")
        rota = tf.keras.layers.RandomRotation(0.2, fill_mode='constant')
        # Bounds written as (lower, upper): zoom in by 5% to 15%.
        zoom = tf.keras.layers.RandomZoom(
            height_factor=(-0.15, -0.05),
            width_factor=(-0.15, -0.05),
        )
        trans = tf.keras.layers.RandomTranslation(
            height_factor=(-0.1, 0.1),
            width_factor=(-0.1, 0.1),
        )
        crop = tf.keras.layers.RandomCrop(height, width)
        return tf.keras.Sequential([flip, zoom, crop, trans, rota])

    augmenter = build_augmenter()

    # One extra float channel is enough to carry the mask through the pipeline.
    mask = tf.cast(mask, 'float32')[..., tf.newaxis]
    images_mask = tf.concat([images, mask], -1)
    # training=True so the random layers are active when called directly
    # (outside of Model.fit).
    images_mask = augmenter(images_mask, training=True)

    image = images_mask[:, :, :-1]
    mask = images_mask[:, :, -1]
    # NOTE(review): the layers interpolate, so mask values at object borders
    # may be fractional before this truncating cast.
    return image, tf.cast(mask, 'uint8')
Then you can map your dataset:
# create dataset
dataset =
# Load every element at 400x400.
dataset = dataset.map(lambda x: load_dataset(x, (400, 400)))
# aug. dataset
# Apply the joint image/mask augmentation to every (image, mask) element.
dataset_aug = dataset.map(lambda x, y: augment_using_layers(x, y, (400, 400)))
Upvotes: 1
Reputation: 41
def Augment(tar_shape=(512,512), seed=37):
    """Build a two-input Keras model that augments an image and its mask.

    For every augmentation step two separate layers with identical arguments
    and the same ``seed`` are created, one applied to the image branch and
    one to the mask branch.

    Args:
        tar_shape: ``(height, width)`` passed to the final random crop.
        seed: Seed given to every random layer.

    Returns:
        A ``tf.keras.Model`` taking ``(img, msk)`` and returning the
        augmented ``(img, msk)``.
    """
    img = tf.keras.Input(shape=(None,None,3))
    msk = tf.keras.Input(shape=(None,None,1))

    # (layer class, positional arguments), in the order they are applied.
    steps = [
        (tf.keras.layers.RandomFlip, ()),
        (tf.keras.layers.RandomTranslation, ((-0.75, 0.75), (-0.75, 0.75))),
        (tf.keras.layers.RandomRotation, ((-0.35, 0.35),)),
        (tf.keras.layers.RandomZoom, ((-0.1, 0.05), (-0.1, 0.05))),
        (tf.keras.layers.RandomCrop, (tar_shape[0], tar_shape[1])),
    ]

    aug_img, aug_msk = img, msk
    for layer_cls, args in steps:
        # Image layer first, then its mask twin.
        aug_img = layer_cls(*args, seed=seed)(aug_img)
        aug_msk = layer_cls(*args, seed=seed)(aug_msk)

    return tf.keras.Model(inputs=(img, msk), outputs=(aug_img, aug_msk))
# Build the augmentation model once (the name is rebound from the factory to
# the model) and apply it to every (image, mask) element of the dataset.
Augment = Augment()
ds_train = ds_train.map(lambda img, msk: Augment((img, msk)), num_parallel_calls=AUTOTUNE)
Upvotes: 1
Reputation: 168
Here is my own implementation, in case someone else wants to use the TensorFlow built-ins (tf.image API), as of December 2020 :)
def _resize_pair(image, mask):
    """Resize an image/mask pair to ``(IMG_SIZE, IMG_SIZE)``."""
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    # Nearest-neighbour keeps mask values from being blended into in-between
    # values; the cast keeps the float32 output dtype of the default resize.
    mask = tf.image.resize(mask, (IMG_SIZE, IMG_SIZE), method='nearest')
    return image, tf.cast(mask, tf.float32)


def load_image(datapoint, augment=True):
    """Resize, normalise and optionally augment one image/mask pair.

    Args:
        datapoint: Mapping with the keys ``'image'`` and
            ``'segmentation_mask'``. Assumes the image is RGB with values in
            0..255 -- confirm against the dataset.
        augment: When true, apply a random central zoom, brightness and
            contrast jitter (image only), random flips and a random rotation.
            Geometric changes are applied identically to image and mask.

    Returns:
        ``(input_image, input_mask)``: the grayscale image scaled by 1/255
        and the mask, both resized to ``(IMG_SIZE, IMG_SIZE)``.
    """
    raw_image = datapoint['image']
    raw_mask = datapoint['segmentation_mask']

    # resize image and mask
    input_image, input_mask = _resize_pair(raw_image, raw_mask)

    if augment:
        # zoom in a bit
        if tf.random.uniform(()) > 0.5:
            # Crop the un-resized originals so the zoom keeps their
            # resolution, then resize. Both branches produce the same shape
            # and still go through the grayscale/rescale step below.
            input_image, input_mask = _resize_pair(
                tf.image.central_crop(raw_image, 0.75),
                tf.image.central_crop(raw_mask, 0.75),
            )

    # rescale the image
    input_image = tf.image.rgb_to_grayscale(input_image)
    input_image = tf.cast(input_image, tf.float32) / 255.0

    if augment:
        # random brightness adjustment illumination
        input_image = tf.image.random_brightness(input_image, 0.3)
        # random contrast adjustment
        input_image = tf.image.random_contrast(input_image, 0.2, 0.5)

        # flipping random horizontal or vertical
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            input_mask = tf.image.flip_left_right(input_mask)
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_up_down(input_image)
            input_mask = tf.image.flip_up_down(input_mask)

        # rotation in 30° steps: 12 steps of pi/6 cover the full circle
        rot_factor = tf.cast(
            tf.random.uniform(shape=[], maxval=12, dtype=tf.int32), tf.float32)
        angle = np.pi / 6 * rot_factor
        input_image = tfa.image.rotate(input_image, angle)
        input_mask = tfa.image.rotate(input_mask, angle, interpolation='nearest')

    return input_image, input_mask
Upvotes: 4
Reputation: 4258
You can try external libraries for extra image augmentations. These links may help with image augmentation along with the segmentation mask:
Upvotes: 2