Ian_SP

Reputation: 13

Changing activation function during training in Tensorflow

I'm currently trying to figure out a way to change my activation function(s) during training (TensorFlow 2). For example, consider a leaky ReLU max(x, x/a) with a = 1 at the start of training; as training progresses I want to increase a slowly.
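To make the shape of this activation concrete, here is a quick pure-Python sketch (no TensorFlow needed):

```python
# max(x, x/a): for x >= 0 the output is x; for x < 0 it is x/a,
# so increasing a flattens the negative branch toward zero.

def leaky(x, a):
    return max(x, x / a)

print(leaky(-3.0, 1))  # -3.0  (a = 1: identity)
print(leaky(-3.0, 3))  # -1.0  (a = 3: negative side compressed)
print(leaky(2.0, 3))   # 2.0   (positive side unchanged)
```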

What I’ve tried so far (via a callback) (minimal example):

import tensorflow as tf
from tensorflow import keras as ke

class change_act_func(ke.callbacks.Callback):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        if epoch == 2:
            def my_l_relu(x):
                a = 3
                return tf.math.maximum(x, x/a)

            # Re-assigning the activation attribute after compilation
            # is what doesn't seem to take effect:
            self.model.layers[1].activation = my_l_relu

This doesn't seem to work, though. Any hints/ideas? Many thanks in advance :)

PS: Here is an interactive example of the described activation function: https://www.desmos.com/calculator/mawjdaom0f

Upvotes: 0

Views: 47

Answers (1)

Ian_SP

Reputation: 13

I think I have a good-enough solution. For anyone interested, here is a minimal example (TensorFlow 2.15.0):

#################################

import tensorflow as tf
from tensorflow.keras.layers import Layer

class My_LeakyReLU(Layer):
    """Leaky ReLU max(x, alpha*x) whose slope is a non-trainable
    weight that can be updated from a callback via assign()."""

    def __init__(self, alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        # Stored so that get_config()/from_config() round-trips work
        self.initial_alpha = alpha

    def build(self, input_shape):
        # Scalar, non-trainable: the optimizer ignores it, but it lives
        # in the graph and can be reassigned between epochs.
        self.alpha = self.add_weight(
            name='alpha', shape=(),
            initializer=tf.keras.initializers.Constant(self.initial_alpha),
            trainable=False)

    def call(self, inputs):
        return tf.math.maximum(inputs, tf.math.multiply(inputs, self.alpha))

    def get_config(self):
        config = {"alpha": float(self.alpha.numpy())}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

#################################

from tensorflow import keras as ke

class change_activation(ke.callbacks.Callback):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        # Keras numbers epochs from 0, so divide by epoch + 1
        # to avoid a ZeroDivisionError on the first epoch.
        neg_slope_act = 1.0 / (epoch + 1)
        self.model.layers[42].alpha.assign(neg_slope_act)
        # where 42 is the index of the "My_LeakyReLU" layer

#################################
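To see what values the callback feeds into the layer, here is a minimal pure-Python sketch of the per-epoch slope schedule (assuming a 1/(epoch + 1) decay, since `on_epoch_begin` receives epoch numbers starting at 0):

```python
# Sketch of the negative-slope schedule applied by the callback above.

def neg_slope_schedule(epoch):
    """Slope starts at 1.0 (identity-like) and decays each epoch."""
    return 1.0 / (epoch + 1)

def my_leaky_relu(x, alpha):
    """max(x, alpha * x): identity for x >= 0, scaled for x < 0."""
    return max(x, alpha * x)

# Slopes for the first few epochs: 1.0, 0.5, ~0.33, 0.25, ...
alphas = [neg_slope_schedule(e) for e in range(4)]

print(my_leaky_relu(-4.0, alphas[1]))  # -2.0
print(my_leaky_relu(3.0, alphas[1]))   # 3.0
```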

Upvotes: 0
