Tom

Reputation: 168

Keras data augmentation pipeline for image segmentation dataset (image and mask with same manipulation)

I am building a preprocessing and data augmentation pipeline for my image segmentation dataset. Keras has a powerful API for this, but I ran into the problem of reproducing the same augmentation on both the image and the segmentation mask (a second image). Both must undergo exactly the same manipulations. Is this not supported yet?

https://www.tensorflow.org/tutorials/images/data_augmentation

Example / Pseudocode

data_augmentation = tf.keras.Sequential([
layers.experimental.preprocessing.RandomFlip(mode="horizontal_and_vertical", seed=SEED_VAL),
layers.experimental.preprocessing.RandomRotation(factor=0.4, fill_mode="constant", fill_value=0, seed=SEED_VAL),
layers.experimental.preprocessing.RandomZoom(height_factor=(-0.0,-0.2), fill_mode='constant', fill_value=0, seed=SEED_VAL)])

(train_ds, test_ds), info = tfds.load('somedataset', split=['train[:80%]', 'train[80%:]'], with_info=True)

This code does not work but illustrates how my dream api would work:

train_ds = train_ds.map(lambda datapoint: data_augmentation((datapoint['image'], datapoint['segmentation_mask']), training=True))

Alternative

The alternative is to write a custom load and manipulation / randomization method, as proposed in the image segmentation tutorial (https://www.tensorflow.org/tutorials/images/segmentation).

Any tips on state-of-the-art data augmentation for this type of dataset are much appreciated :)

Upvotes: 4

Views: 6872

Answers (5)

R Kalia

Reputation: 17

This is the method described in the official docs, [Image Segmentation Official Tutorials][1]

class Augment(tf.keras.layers.Layer):
  def __init__(self, seed=42):
    super().__init__()
    # both pipelines get the same seed, so they make the same random changes
    self.augment_inputs = tf.keras.Sequential([
                          layers.experimental.preprocessing.RandomFlip(mode="horizontal_and_vertical", seed=seed),
                          layers.experimental.preprocessing.RandomRotation(factor=0.4, fill_mode="constant", fill_value=0, seed=seed),
                          layers.experimental.preprocessing.RandomZoom(height_factor=(-0.0,-0.2), fill_mode='constant', fill_value=0, seed=seed)])
    self.augment_labels = tf.keras.Sequential([
                          layers.experimental.preprocessing.RandomFlip(mode="horizontal_and_vertical", seed=seed),
                          layers.experimental.preprocessing.RandomRotation(factor=0.4, fill_mode="constant", fill_value=0, seed=seed),
                          layers.experimental.preprocessing.RandomZoom(height_factor=(-0.0,-0.2), fill_mode='constant', fill_value=0, seed=seed)])

  def call(self, inputs, labels):
    inputs = self.augment_inputs(inputs)
    labels = self.augment_labels(labels)
    return inputs, labels

After this, you can map the Augment layer over the batched dataset:

train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE))

This ensures that your inputs and masks receive the same random augmentation.

[1]: https://www.tensorflow.org/tutorials/images/segmentation

Upvotes: -1

Stanislav D.

Reputation: 110

I solved this by concatenating the image and the mask along the channel axis into a single tensor and then applying the augmentation layers to that tensor.

def augment_using_layers(images, mask, size=None):
    
    if size is None:
        h_s = mask.shape[0]
        w_s = mask.shape[1]
    else:
        h_s = size[0]
        w_s = size[1]
    
    def aug(height=h_s, width=w_s):

        flip = tf.keras.layers.RandomFlip(mode="horizontal")
        
        rota = tf.keras.layers.RandomRotation(0.2, fill_mode='constant')
        
        zoom = tf.keras.layers.RandomZoom(
                            height_factor=(-0.05, -0.15),
                            width_factor=(-0.05, -0.15)
                            )
        
        trans = tf.keras.layers.RandomTranslation(height_factor=(-0.1, 0.1),
                                            width_factor=(-0.1, 0.1), 
                                            fill_mode='constant')
        
        crop = tf.keras.layers.RandomCrop(h_s, w_s)
        
        layers = [flip, zoom, crop, trans, rota]
        aug_model = tf.keras.Sequential(layers)

        return aug_model
    
    aug = aug()
    
    mask = tf.stack([mask, mask, mask], -1)
    mask = tf.cast(mask, 'float32')

    images_mask = tf.concat([images, mask], -1)  
    images_mask = aug(images_mask)  
    
    # channels 0-2 hold the image, channels 3-5 hold three identical copies of the mask
    image = images_mask[:,:,0:3]
    mask = images_mask[:,:,3]
    
    return image, tf.cast(mask, 'uint8')

Then you can map your dataset:

# create dataset
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.map(lambda x: load_dataset(x, (400, 400)))

# aug. dataset
dataset_aug = dataset.map(lambda x, y: augment_using_layers(x, y, (400, 400)))
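
The idea behind the concat trick can be shown without TensorFlow. The following is only a minimal NumPy sketch, not the code from this answer: `augment_stacked` is a hypothetical helper, and the flip and 90-degree rotation stand in for the Keras layers. Because the mask travels as an extra channel, every spatial operation hits image and mask identically.

```python
import numpy as np

def augment_stacked(image, mask, rng):
    """Flip/rotate image and mask together by stacking them on the channel axis."""
    # image: (H, W, C) array, mask: (H, W) array; append the mask as one extra channel
    stacked = np.concatenate([image, mask[..., None].astype(image.dtype)], axis=-1)
    # each spatial operation below acts on all channels at once,
    # so the image and the mask cannot drift apart
    if rng.random() < 0.5:
        stacked = stacked[:, ::-1, :]  # horizontal flip
    stacked = np.rot90(stacked, k=int(rng.integers(0, 4)), axes=(0, 1))
    # split back: the first C channels are the image, the last one is the mask
    return stacked[..., :-1], stacked[..., -1].astype(mask.dtype)

rng = np.random.default_rng(0)
image = rng.random((4, 4, 3)).astype(np.float32)
mask = (image[..., 0] > 0.5).astype(np.uint8)  # toy mask derived from the image
aug_image, aug_mask = augment_stacked(image, mask, rng)
```

Since the toy mask is computed from the image, recomputing it from `aug_image` and comparing with `aug_mask` is a quick way to check that the pair stayed aligned.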


Upvotes: 1

Gowri Shankar p

Reputation: 41

Giving the image layers and the mask layers a common seed applies the same augmentations to the image and the mask.

def Augment(tar_shape=(512,512), seed=37):
    img = tf.keras.Input(shape=(None,None,3))
    msk = tf.keras.Input(shape=(None,None,1))

    i = tf.keras.layers.RandomFlip(seed=seed)(img)
    m = tf.keras.layers.RandomFlip(seed=seed)(msk)
    i = tf.keras.layers.RandomTranslation((-0.75, 0.75),(-0.75, 0.75),seed=seed)(i)
    m = tf.keras.layers.RandomTranslation((-0.75, 0.75),(-0.75, 0.75),seed=seed)(m)
    i = tf.keras.layers.RandomRotation((-0.35, 0.35),seed=seed)(i)
    m = tf.keras.layers.RandomRotation((-0.35, 0.35),seed=seed)(m)
    i = tf.keras.layers.RandomZoom((-0.1, 0.05),(-0.1, 0.05),seed=seed)(i)
    m = tf.keras.layers.RandomZoom((-0.1, 0.05),(-0.1, 0.05),seed=seed)(m)
    i = tf.keras.layers.RandomCrop(tar_shape[0],tar_shape[1],seed=seed)(i)
    m = tf.keras.layers.RandomCrop(tar_shape[0],tar_shape[1],seed=seed)(m)
    
    return tf.keras.Model(inputs=(img,msk), outputs=(i,m))
Augment = Augment()

ds_train = ds_train.map(lambda img,msk: Augment((img,msk)), num_parallel_calls=AUTOTUNE)

Important:

  • The layers above can change the dtype of the image and mask from int32/uint8 to float32.
  • The output mask can also contain values other than 0/1 (such as 0.9987). This is caused by interpolation. To avoid it, change the interpolation from bilinear to nearest.
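
To see why the interpolation mode matters for masks, here is a small one-dimensional NumPy sketch. It is an illustration only: `upsample_row` is a hypothetical helper, and the Keras layers interpolate in 2-D, but the effect is the same. Linear interpolation blends neighbouring labels into fractional values, while nearest-neighbour only copies existing labels.

```python
import numpy as np

def upsample_row(row, factor, method):
    """Upsample a 1-D row of mask values by an integer factor."""
    src = np.arange(len(row))
    # positions of the output samples, expressed in source coordinates
    dst = np.linspace(0, len(row) - 1, len(row) * factor)
    if method == "nearest":
        # pick the closest source pixel, so only existing labels appear
        return row[np.rint(dst).astype(int)]
    # linear interpolation mixes the two neighbouring labels
    return np.interp(dst, src, row)

mask_row = np.array([0, 0, 1, 1], dtype=np.float32)
linear = upsample_row(mask_row, 2, "linear")    # contains values strictly between 0 and 1
nearest = upsample_row(mask_row, 2, "nearest")  # contains only 0 and 1
```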

Upvotes: 1

Tom

Reputation: 168

Here is my own implementation, in case someone else wants to use the TensorFlow built-ins (tf.image API), as of December 2020 :)

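import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa  # provides tfa.image.rotate

@tf.function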
def load_image(datapoint, augment=True):
    
    # resize image and mask; nearest-neighbour keeps the mask's label values intact
    img_orig = tf.image.resize(datapoint['image'], (IMG_SIZE, IMG_SIZE))
    mask_orig = tf.image.resize(datapoint['segmentation_mask'], (IMG_SIZE, IMG_SIZE), method='nearest')
    input_image, input_mask = img_orig, mask_orig

    # zoom in a bit; done before rescaling so both branches end up in the same value range
    if augment:
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.central_crop(img_orig, 0.75)
            input_mask = tf.image.central_crop(mask_orig, 0.75)
            # resize back to the target size
            input_image = tf.image.resize(input_image, (IMG_SIZE, IMG_SIZE))
            input_mask = tf.image.resize(input_mask, (IMG_SIZE, IMG_SIZE), method='nearest')

    # rescale the image
    if IMAGE_CHANNELS == 1:
        input_image = tf.image.rgb_to_grayscale(input_image)
    input_image = tf.cast(input_image, tf.float32) / 255.0

    # augmentation
    if augment:
        
        # random brightness adjustment illumination
        input_image = tf.image.random_brightness(input_image, 0.3)
        # random contrast adjustment
        input_image = tf.image.random_contrast(input_image, 0.2, 0.5)
        
        # flipping random horizontal or vertical
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            input_mask = tf.image.flip_left_right(input_mask)
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_up_down(input_image)
            input_mask = tf.image.flip_up_down(input_mask)

        # rotation in 15° steps (pi/12 radians each), from 0° up to 165°
        rot_factor = tf.cast(tf.random.uniform(shape=[], maxval=12, dtype=tf.int32), tf.float32)
        angle = np.pi/12*rot_factor
        input_image = tfa.image.rotate(input_image, angle)
        input_mask = tfa.image.rotate(input_mask, angle)

    return input_image, input_mask
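
The pattern in the function above is: draw each random decision once, apply it to both the image and the mask, and apply photometric changes to the image only. The following is a framework-agnostic NumPy sketch of that pattern, not a drop-in replacement: `sample_params`, `apply_params` and `augment_pair` are hypothetical names, and 90-degree rotations stand in for `tfa.image.rotate`.

```python
import numpy as np

def sample_params(rng):
    """Draw the random geometric decisions once per example."""
    return {
        "flip_lr": bool(rng.random() < 0.5),
        "flip_ud": bool(rng.random() < 0.5),
        "k": int(rng.integers(0, 4)),  # number of 90-degree rotations
    }

def apply_params(arr, params):
    """Apply already-sampled geometric transforms to one array (image or mask)."""
    if params["flip_lr"]:
        arr = arr[:, ::-1]
    if params["flip_ud"]:
        arr = arr[::-1, :]
    return np.rot90(arr, k=params["k"], axes=(0, 1))

def augment_pair(image, mask, rng):
    params = sample_params(rng)          # one draw ...
    image = apply_params(image, params)  # ... reused for the image
    mask = apply_params(mask, params)    # ... and for the mask
    # brightness shift on the image only; the labels must stay untouched
    image = np.clip(image + rng.uniform(-0.3, 0.3), 0.0, 1.0)
    return image, mask

rng = np.random.default_rng(1)
image = rng.random((6, 6, 3))
mask = np.arange(36).reshape(6, 6)  # toy mask with a unique label per pixel
aug_image, aug_mask = augment_pair(image, mask, rng)
```

Keeping the sampling step separate from the application step makes it easy to verify that both arrays received the same transform.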

Upvotes: 4

B200011011

Reputation: 4258

You can also use external libraries for additional image augmentations. The following links cover augmenting an image together with its segmentation mask:

albumentations

https://github.com/albumentations-team/albumentations

https://albumentations.ai/docs/getting_started/mask_augmentation/


imgaug

https://github.com/aleju/imgaug

https://nbviewer.jupyter.org/github/aleju/imgaug-doc/blob/master/notebooks/B05%20-%20Augment%20Segmentation%20Maps.ipynb


Upvotes: 2
