Brian Mark Anderson

Reputation: 546

TensorFlow image operations for batches: same shift applied to entire batch

This is very similar to the question asked here: TensorFlow image operations for batches

however, the difference is that I have a 3D image (medical imaging) and I want to apply the same random shift across the entire 3D set. The current code applies a different random shift to EACH image, rather than applying the same shift to all of the images.
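To make the problem concrete, here is a minimal sketch (toy shapes, not my real data) of the behavior I am describing: when RandomTranslation is given a [depth, height, width, channels] tensor, it treats the first axis as a batch axis and draws an independent offset for every slice.

import tensorflow as tf

layer = tf.keras.layers.RandomTranslation(height_factor=0.2, width_factor=0.2)
volume = tf.random.uniform((10, 50, 50, 1))  # hypothetical [depth, height, width, channels] volume
shifted = layer(volume)  # 10 independent random shifts, one per 2D slice, not one shared shift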

The goal is also for this to be compatible with .map

import tensorflow as tf

class ShiftImages(object):
    def __init__(self, keys=('image', 'mask'), channel_dimensions=(1, 1), fill_value=0.0, fill_mode="reflect", interpolation="bilinear",
                 seed=None, height_factor=0.0, width_factor=0.0, on_global_3D=True):
        """
        Args:
            height_factor: a float represented as fraction of value, or a tuple of
                size 2 representing lower and upper bound for shifting vertically.
                A negative value means shifting image up, while a positive value means
                shifting image down. For instance, `height_factor=(-0.2, 0.3)` results
                in an output shifted by a random amount in the range `[-20%, +30%]`.
                `height_factor=0.2` results in an output height shifted by a random
                amount in the range `[-20%, +20%]`.
            width_factor: a float represented as fraction of value, or a tuple of
                size 2 representing lower and upper bound for shifting horizontally.
                A negative value means shifting image left, while a positive value means
                shifting image right. When represented as a single positive float,
                this value is used for both the upper and lower bound.
            fill_mode: Points outside the boundaries of the input are filled
                according to the given mode. Available methods are `"constant"`,
                `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"reflect"` here.
            interpolation: Interpolation mode. Supported values: `"nearest"`,
                `"bilinear"`.
            seed: Integer. Used to create a random seed.
            fill_value: a float represents the value to be filled outside the
                boundaries when `fill_mode="constant"`.
        """
        self.height_factor = height_factor
        self.width_factor = width_factor
        self.interpolation = interpolation
        self.fill_value = fill_value
        self.fill_mode = fill_mode
        self.random_translation = tf.keras.layers.RandomTranslation(height_factor=height_factor,
                                                                    width_factor=width_factor,
                                                                    interpolation=interpolation,
                                                                    fill_value=fill_value, seed=seed,
                                                                    fill_mode=fill_mode)
        self.keys = keys
        self.channel_dimensions = channel_dimensions
        self.global_3D = on_global_3D

    def shift(self, image_features):
        # Extract the relevant images (e.g., 'image' and 'mask') from the dataset
        combine_images = [image_features[key] for key in self.keys]
        shift_image = tf.concat(combine_images, axis=-1)

        # Apply the random translation (shifting)
        shifted_image = self.random_translation(shift_image)

        # Split back into the original components and return as a dictionary
        start_dim = 0
        for key in self.keys:
            dimension = image_features[key].shape[-1]
            end_dim = start_dim + dimension
            new_image = shifted_image[..., start_dim:end_dim]
            start_dim += dimension
            image_features[key] = new_image
        return image_features

# Example of how to use this with a tf.data.Dataset

# Create a sample dataset with 'image' and 'mask' as keys
dataset = tf.data.Dataset.from_tensor_slices({
    'image': tf.random.uniform(shape=(10, 128, 128, 1)),  # 10 random images
    'mask': tf.random.uniform(shape=(10, 128, 128, 1))    # 10 random masks
})

# Instantiate the ShiftImages class
shift_images = ShiftImages(height_factor=0.1, width_factor=0.1)

# Apply the shift function to the dataset using the map method
shifted_dataset = dataset.map(lambda x: shift_images.shift(x))
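As a quick sanity check that the mapped dataset keeps the original structure (just an inspection loop, nothing required):

for element in shifted_dataset.take(1):
    print(element['image'].shape, element['mask'].shape)  # both (128, 128, 1)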

Upvotes: 0

Views: 12

Answers (1)

Brian Mark Anderson

Reputation: 546

Figured this out, at least in a way that works for what I want.

The solution was to make separate height and width translators. Reshape the 3D image [10, 50, 50] -> [500, 50] and apply the width translator. Reshape back to 3D [10, 50, 50]. Transpose the height and width axes, reshape back to [500, 50], apply the 'height' translator, then reshape and transpose back.
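As a standalone sketch of just the reshape/transpose bookkeeping (toy shapes matching the numbers above, with a channel axis added and the translators left out):

import tensorflow as tf

volume = tf.random.uniform((10, 50, 50, 1))  # [depth, height, width, channels]

# Width pass: stack depth and height so one 2D translation shifts every slice identically
flat_w = tf.reshape(volume, (10 * 50, 50, 1))   # [500, 50, 1]
# ... apply the width translator to flat_w here ...
volume = tf.reshape(flat_w, (10, 50, 50, 1))

# Height pass: swap the height and width axes, then do the same stacking
swapped = tf.transpose(volume, (0, 2, 1, 3))    # height and width swapped
flat_h = tf.reshape(swapped, (10 * 50, 50, 1))  # [500, 50, 1]
# ... apply the 'height' translator to flat_h here ...
volume = tf.transpose(tf.reshape(flat_h, (10, 50, 50, 1)), (0, 2, 1, 3))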

class ShiftImages(object):
    def __init__(self, keys=('image', 'mask'), channel_dimensions=(1, 1), fill_value=0.0, fill_mode="reflect", interpolation="bilinear",
                 seed=None, height_factor=0.0, width_factor=0.0, on_global_3D=True, image_shape=(32, 320, 320, 3)):
        """
    Args:
        height_factor: a float represented as fraction of value, or a tuple of
            size 2 representing lower and upper bound for shifting vertically. A
            negative value means shifting image up, while a positive value means
            shifting image down. When represented as a single positive float,
            this value is used for both the upper and lower bound. For instance,
            `height_factor=(-0.2, 0.3)` results in an output shifted by a random
            amount in the range `[-20%, +30%]`. `height_factor=0.2` results in
            an output height shifted by a random amount in the range
            `[-20%, +20%]`.
        width_factor: a float represented as fraction of value, or a tuple of
            size 2 representing lower and upper bound for shifting horizontally.
            A negative value means shifting image left, while a positive value
            means shifting image right. When represented as a single positive
            float, this value is used for both the upper and lower bound. For
            instance, `width_factor=(-0.2, 0.3)` results in an output shifted
            left by 20%, and shifted right by 30%. `width_factor=0.2` results
            in an output height shifted left or right by 20%.
        fill_mode: Points outside the boundaries of the input are filled
            according to the given mode. Available methods are `"constant"`,
            `"nearest"`, `"wrap"` and `"reflect"`. Defaults to `"reflect"` here.
            - `"reflect"`: `(d c b a | a b c d | d c b a)`
                The input is extended by reflecting about the edge of the last
                pixel.
            - `"constant"`: `(k k k k | a b c d | k k k k)`
                The input is extended by filling all values beyond
                the edge with the same constant value k specified by
                `fill_value`.
            - `"wrap"`: `(a b c d | a b c d | a b c d)`
                The input is extended by wrapping around to the opposite edge.
            - `"nearest"`: `(a a a a | a b c d | d d d d)`
                The input is extended by the nearest pixel.
            Note that when using torch backend, `"reflect"` is redirected to
            `"mirror"` `(c d c b | a b c d | c b a b)` because torch does not
            support `"reflect"`.
            Note that torch backend does not support `"wrap"`.
        interpolation: Interpolation mode. Supported values: `"nearest"`,
            `"bilinear"`.
        seed: Integer. Used to create a random seed.
        fill_value: a float represents the value to be filled outside the
            boundaries when `fill_mode="constant"`.
        """
        self.og_shape = image_shape
        self.height_factor = height_factor
        self.width_factor = width_factor
        self.interpolation = interpolation
        self.fill_value = fill_value
        self.fill_mode = fill_mode
        """
        I know that this has height_factor equal to 0, do not change it! We reshape things later
        """
        self.random_translation_height = tf.keras.layers.RandomTranslation(height_factor=0.0,
                                                                           width_factor=height_factor,
                                                                           interpolation=interpolation,
                                                                           fill_value=fill_value, seed=seed,
                                                                           fill_mode=fill_mode)
        self.random_translation_width = tf.keras.layers.RandomTranslation(height_factor=0.0,
                                                                          width_factor=width_factor,
                                                                          interpolation=interpolation,
                                                                          fill_value=fill_value, seed=seed,
                                                                          fill_mode=fill_mode)
        self.keys = keys
        self.channel_dimensions = channel_dimensions
        self.global_3D = on_global_3D

    def parse(self, image_features, *args, **kwargs):
        # _check_keys_ is a small validation helper from my codebase; it just asserts the keys exist
        _check_keys_(input_features=image_features, keys=self.keys)
        combine_images = [image_features[i] for i in self.keys]
        shift_image = tf.concat(combine_images, axis=-1)
        og_shape = self.og_shape
        if self.width_factor != 0.0:
            if self.global_3D:
                # Stack depth and height: [depth, height, width, channels] -> [depth * height, width, channels],
                # so a single random offset along the width axis is shared by the whole volume
                shift_image = tf.reshape(shift_image, [og_shape[0] * og_shape[1]] + [i for i in og_shape[2:]])
            shift_image = self.random_translation_width(shift_image)
            if self.global_3D:
                shift_image = tf.reshape(shift_image, og_shape)
        if self.height_factor != 0.0:
            if self.global_3D:
                # Swap height and width, then stack depth and (original) width:
                # [depth, width, height, channels] -> [depth * width, height, channels]
                shift_image = tf.reshape(tf.transpose(shift_image, [0, 2, 1, 3]),
                                         [og_shape[0] * og_shape[2], og_shape[1]] + [i for i in og_shape[3:]])
            shift_image = self.random_translation_height(shift_image)
            if self.global_3D:
                # Undo the stacking and the transpose to return to [depth, height, width, channels]
                shift_image = tf.reshape(shift_image, [og_shape[0], og_shape[2], og_shape[1]] + [i for i in og_shape[3:]])
                shift_image = tf.transpose(shift_image, [0, 2, 1, 3])
        shifted_image = shift_image
        # Split back into the original components and return as a dictionary
        start_dim = 0
        for key in self.keys:
            dimension = image_features[key].shape[-1]
            end_dim = start_dim + dimension
            new_image = shifted_image[..., start_dim:end_dim]
            start_dim += dimension
            image_features[key] = new_image
        return image_features
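And a minimal usage sketch with tf.data (the shapes are made up; note that image_shape is the shape of the concatenated image + mask, and _check_keys_ is the validation helper from my own codebase mentioned above):

dataset = tf.data.Dataset.from_tensor_slices({
    'image': tf.random.uniform(shape=(4, 32, 320, 320, 1)),  # 4 volumes of [depth, height, width, channels]
    'mask': tf.random.uniform(shape=(4, 32, 320, 320, 1)),
})

shifter = ShiftImages(keys=('image', 'mask'), height_factor=0.1, width_factor=0.1,
                      image_shape=(32, 320, 320, 2))  # 2 = image channels + mask channels
shifted_dataset = dataset.map(shifter.parse)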

The main drawback is that a specific image size is required to ensure the stack isn't 'ragged'. I will also look into adding depth-wise shifts.
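If the fixed image_shape becomes a problem, one possible workaround (just a sketch, I have not tried it in the full class) would be to read the shape at runtime with tf.shape instead:

import tensorflow as tf

def stack_depth_and_height(volume):
    # Sketch: flatten [depth, height, width, channels] into [depth * height, width, channels]
    # using the runtime shape rather than a hard-coded one
    shape = tf.shape(volume)
    return tf.reshape(volume, [shape[0] * shape[1], shape[2], shape[3]])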

Upvotes: 0
