Reputation: 546
This is very similar to the question asked here: TensorFlow image operations for batches
however, the distinction is that I have a 3D image (medical imaging) and I wanted to apply random shifts across the entire 3D set. The current code will apply a random shift across EACH image, rather than apply the same shift across all of the images.
The goal is also for this to be compatible with `.map` on a `tf.data.Dataset`.
import tensorflow as tf
class ShiftImages(object):
    """Randomly translate a batch of images, moving every tensor listed in
    ``keys`` (e.g. an image and its segmentation mask) by the same offset.

    The tensors are concatenated along the channel axis before translation so
    that image and mask stay aligned. Note that
    ``tf.keras.layers.RandomTranslation`` draws an independent offset for each
    image in the batch, so every 2D slice of a stacked volume is shifted
    separately.
    """

    def __init__(self, keys=('image', 'mask'), channel_dimensions=(1, 1), fill_value=None, fill_mode="reflect", interpolation="bilinear",
                 seed=None, height_factor=0.0, width_factor=0.0, on_global_3D=True):
        """
        Args:
            keys: dict keys of the tensors to shift together.
            channel_dimensions: channel count per key (stored but unused here;
                channel widths are read from the tensors themselves).
            fill_value: a float; value used outside the input boundaries when
                `fill_mode="constant"`.
            fill_mode: how points outside the input are filled. One of
                `"constant"`, `"nearest"`, `"wrap"`, `"reflect"`. Defaults to
                `"reflect"`.
            interpolation: interpolation mode, `"nearest"` or `"bilinear"`.
            seed: Integer. Used to create a random seed.
            height_factor: a float (or a (lower, upper) tuple) giving the
                fraction of the height to shift vertically. Negative shifts
                up, positive shifts down; e.g. `height_factor=0.2` shifts by a
                random amount in `[-20%, +20%]`, and `(-0.2, 0.3)` by an
                amount in `[-20%, +30%]`.
            width_factor: a float (or a (lower, upper) tuple) giving the
                fraction of the width to shift horizontally. Negative shifts
                left, positive shifts right; a single float is used for both
                bounds.
            on_global_3D: stored flag (not consulted by this version).
        """
        self.height_factor = height_factor
        self.width_factor = width_factor
        self.interpolation = interpolation
        self.fill_value = fill_value
        self.fill_mode = fill_mode
        # One translation layer handles both axes; Keras draws the random
        # offsets per image at call time.
        self.random_translation = tf.keras.layers.RandomTranslation(
            height_factor=height_factor,
            width_factor=width_factor,
            interpolation=interpolation,
            fill_value=fill_value,
            seed=seed,
            fill_mode=fill_mode,
        )
        self.keys = keys
        self.channel_dimensions = channel_dimensions
        self.global_3D = on_global_3D

    def shift(self, image_features):
        """Shift all tensors in ``self.keys`` together and return the dict."""
        # Concatenate along channels so a single translation moves everything
        # identically (keeping image and mask registered with each other).
        stacked = tf.concat([image_features[key] for key in self.keys], axis=-1)
        shifted = self.random_translation(stacked)
        # Hand each key back its own channel slice, in the original order.
        offset = 0
        for key in self.keys:
            n_channels = image_features[key].shape[-1]
            image_features[key] = shifted[..., offset:offset + n_channels]
            offset += n_channels
        return image_features
# Example: wiring ShiftImages into a tf.data pipeline.
# Build a toy dataset of 10 paired 128x128 single-channel images and masks.
sample_data = {
    'image': tf.random.uniform(shape=(10, 128, 128, 1)),  # 10 random images
    'mask': tf.random.uniform(shape=(10, 128, 128, 1)),   # 10 random masks
}
dataset = tf.data.Dataset.from_tensor_slices(sample_data)
# Configure a shift of up to +/-10% in each direction.
shift_images = ShiftImages(height_factor=0.1, width_factor=0.1)
# The bound method is directly usable as the map function.
shifted_dataset = dataset.map(shift_images.shift)
Upvotes: 0
Views: 12
Reputation: 546
Figured this out, at least in a way that works for what I'm wanting
The solution was to make a height translator and a width translator. Reshape the 3D image [10, 50, 50] -> [500, 50] and apply the width translator, then reshape back to 3D [10, 50, 50]. Next, transpose the axes, reshape again to [500, 50], apply the 'height' translator, and finally reshape and transpose back.
class ShiftImages(object):
    """Randomly shift a 3D stack of 2D slices so the SAME in-plane translation
    is applied to every slice, and identically to every tensor named in
    ``keys`` (so image and mask stay registered).

    The trick: ``tf.keras.layers.RandomTranslation`` draws an independent
    offset per image in a batch, so the whole volume is reshaped into ONE tall
    2D image before translating, forcing a single shared offset. A transpose
    moves the height axis into the width position so a width-only translator
    can shift vertically too. Designed for use with ``tf.data.Dataset.map``.
    """

    def __init__(self, keys=('image', 'mask'), channel_dimensions=(1, 1), fill_value=None, fill_mode="reflect", interpolation="bilinear",
                 seed=None, height_factor=0.0, width_factor=0.0, on_global_3D=True, image_shape=(32, 320, 320, 3)):
        """
        Args:
            keys: dict keys of the tensors to shift together.
            channel_dimensions: channel count per key (stored but unused here;
                channel widths are read from the tensors themselves).
            fill_value: a float; value used outside the input boundaries when
                `fill_mode="constant"`.
            fill_mode: how points outside the input are filled. One of
                `"constant"`, `"nearest"`, `"wrap"`, `"reflect"`. Defaults to
                `"reflect"`.
                - `"reflect"`: `(d c b a | a b c d | d c b a)`
                - `"constant"`: `(k k k k | a b c d | k k k k)` with k = fill_value
                - `"wrap"`: `(a b c d | a b c d | a b c d)`
                - `"nearest"`: `(a a a a | a b c d | d d d d)`
            interpolation: interpolation mode, `"nearest"` or `"bilinear"`.
            seed: Integer. Used to create a random seed.
            height_factor: a float (or a (lower, upper) tuple) giving the
                fraction of the height to shift vertically. Negative shifts
                up, positive shifts down; e.g. `height_factor=0.2` shifts by a
                random amount in `[-20%, +20%]`.
            width_factor: a float (or a (lower, upper) tuple) giving the
                fraction of the width to shift horizontally. Negative shifts
                left, positive shifts right.
            on_global_3D: if True, one shared shift is applied to the whole 3D
                stack; if False, each slice draws its own random shift.
            image_shape: static (depth, height, width, channels) shape of the
                CONCATENATED stack (channels summed over all keys). A fixed
                shape is required so the reshapes are not ragged.
        """
        self.og_shape = image_shape
        self.height_factor = height_factor
        self.width_factor = width_factor
        self.interpolation = interpolation
        self.fill_value = fill_value
        self.fill_mode = fill_mode
        # Both layers deliberately use height_factor=0.0 and do ALL shifting
        # through width_factor: `parse` reshapes (and, for the height shift,
        # transposes) the volume so the axis to be shifted always sits in the
        # width position. Do not "fix" these to use height_factor directly.
        self.random_translation_height = tf.keras.layers.RandomTranslation(height_factor=0.0,
                                                                           width_factor=height_factor,
                                                                           interpolation=interpolation,
                                                                           fill_value=fill_value, seed=seed,
                                                                           fill_mode=fill_mode)
        self.random_translation_width = tf.keras.layers.RandomTranslation(height_factor=0.0,
                                                                          width_factor=width_factor,
                                                                          interpolation=interpolation,
                                                                          fill_value=fill_value, seed=seed,
                                                                          fill_mode=fill_mode)
        self.keys = keys
        self.channel_dimensions = channel_dimensions
        self.global_3D = on_global_3D

    def parse(self, image_features, *args, **kwargs):
        """Shift every tensor in ``self.keys`` by the same random offset.

        Fixes vs. the earlier draft: the width branch is gated on
        ``width_factor`` (it was gated on ``height_factor``), the height
        branch actually calls ``random_translation_height`` (which was never
        used — both branches called the width translator), and the reshape
        after the height shift restores the transposed [D, W, H, C] layout
        instead of jumping straight to ``og_shape`` (which only coincided for
        square slices).
        """
        _check_keys_(input_features=image_features, keys=self.keys)
        # Concatenate along channels so one translation moves image and mask
        # identically.
        combine_images = [image_features[key] for key in self.keys]
        shift_image = tf.concat(combine_images, axis=-1)
        og_shape = self.og_shape
        depth, height, width = og_shape[0], og_shape[1], og_shape[2]
        if self.width_factor != 0.0:
            if self.global_3D:
                # [D, H, W, C] -> [D*H, W, C]: one tall 2D image, so Keras
                # draws a single offset shared by the whole volume.
                shift_image = tf.reshape(shift_image, [depth * height, width] + list(og_shape[3:]))
            shift_image = self.random_translation_width(shift_image)
            if self.global_3D:
                shift_image = tf.reshape(shift_image, og_shape)
        if self.height_factor != 0.0:
            # Move H into the width position so the width-only translator
            # shifts vertically: [D, H, W, C] -> [D, W, H, C]. (Done for the
            # per-slice case too, otherwise H could never be shifted.)
            shift_image = tf.transpose(shift_image, [0, 2, 1, 3])
            transposed_shape = [depth, width, height] + list(og_shape[3:])
            if self.global_3D:
                # [D, W, H, C] -> [D*W, H, C]: same single-offset trick.
                shift_image = tf.reshape(shift_image, [depth * width, height] + list(og_shape[3:]))
            shift_image = self.random_translation_height(shift_image)
            if self.global_3D:
                # Restore [D, W, H, C] — NOT og_shape: H and W are swapped
                # here and only coincide with og_shape for square slices.
                shift_image = tf.reshape(shift_image, transposed_shape)
            shift_image = tf.transpose(shift_image, [0, 2, 1, 3])
        shifted_image = shift_image
        # Split the concatenated result back into per-key tensors, preserving
        # each key's original channel width.
        start_dim = 0
        for key in self.keys:
            dimension = image_features[key].shape[-1]
            end_dim = start_dim + dimension
            image_features[key] = shifted_image[..., start_dim:end_dim]
            start_dim = end_dim
        return image_features
Main drawback is that a specific image size is required to ensure the stack isn't 'ragged'. Will look into adding depth-wise shifts too
Upvotes: 0