eymerich92

Reputation: 43

Is there a way to build a keras preprocessing layer that randomly rotates at specified angles?

I'm working on an astronomical image classification project and I'm currently using Keras to build CNNs.

I'm trying to build a preprocessing pipeline to augment my dataset with Keras/TensorFlow layers. To keep things simple, I would like to implement random transformations from the dihedral group (i.e., for square images, 90-degree rotations and flips), but it seems that tf.keras.preprocessing.image.random_rotation only draws the rotation angle uniformly from a continuous range.

I was wondering whether there is a way to instead choose from a list of specified degrees, in my case [0, 90, 180, 270].

Upvotes: 0

Views: 1256

Answers (1)

Lescurel

Reputation: 11651

Fortunately, there is a TensorFlow function that does what you want: tf.image.rot90, which rotates images counter-clockwise by k multiples of 90 degrees. The next step is to wrap that function in a custom PreprocessingLayer that picks k at random.

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers.experimental.preprocessing import PreprocessingLayer

class RandomRot90(PreprocessingLayer):
    def __init__(self, name=None, **kwargs) -> None:
        super(RandomRot90, self).__init__(name=name, **kwargs)
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)
    
    def call(self, inputs, training=True):
        if training is None:
            training = K.learning_phase()
        
        def random_rot90():
            # random int between 0 and 3
            rot = tf.random.uniform((),0,4, dtype=tf.int32)
            return tf.image.rot90(inputs, k=rot)
        
        # if not training, do nothing
        outputs = tf.cond(training, random_rot90, lambda:inputs)
        outputs.set_shape(inputs.shape)
        return outputs
    
    def compute_output_shape(self, input_shape):
        return input_shape
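  • Note that you might want to implement get_config if you want to be able to save and load a model with that layer. (See the documentation)
  • Note also that this layer might fail if your inputs are not square (height != width), because a 90- or 270-degree rotation swaps height and width, so the output shape no longer matches the input shape.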
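
Since the question also mentions flips, here is a minimal sketch of how the same idea could be tried outside a layer, as a plain function. The helper name random_dihedral is hypothetical and not part of the answer above; it only combines tf.image.rot90 with tf.image.random_flip_left_right, and it assumes square images in a 4-D batch.

```python
import tensorflow as tf


def random_dihedral(images):
    """Apply a random rotation by a multiple of 90 degrees, then a random
    left-right flip, to a 4-D batch shaped (batch, height, width, channels).

    Hypothetical helper for illustration only; assumes square images.
    """
    # Draw k from {0, 1, 2, 3}; maxval is exclusive.
    k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)
    # A scalar k rotates every image in the batch by the same angle.
    rotated = tf.image.rot90(images, k=k)
    # A rotation followed by an optional mirror can reach all 8 dihedral elements.
    return tf.image.random_flip_left_right(rotated)


# tf.image.rot90 rotates counter-clockwise: k=1 is a single 90-degree turn.
image = tf.constant([[1, 2], [3, 4]])[..., tf.newaxis]  # shape (2, 2, 1)
quarter_turn = tf.image.rot90(image, k=1)[..., 0]

# Small square batch: 2 images of 4x4 pixels with 1 channel.
batch = tf.reshape(tf.range(32, dtype=tf.float32), (2, 4, 4, 1))
augmented = random_dihedral(batch)
```

As in the layer above, a single k is drawn per call, so every image in the batch gets the same rotation. If each image should get its own angle, the rotation would have to be applied per image, for example when mapping over an unbatched dataset.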

Upvotes: 6
