mrgloom

Reputation: 21622

How to implement CRelu in Keras?

I'm trying to implement a CReLU layer in Keras.

One option that seems to work is to use a Lambda layer:

import tensorflow as tf
from keras.layers import Conv2D, BatchNormalization, Lambda

def _crelu(x):
    # CReLU: concat(relu(x), relu(-x)) along the last axis
    return tf.nn.crelu(x, axis=-1)

def _conv_bn_crelu(x, n_filters, kernel_size):
    x = Conv2D(filters=n_filters, kernel_size=kernel_size, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Lambda(_crelu)(x)
    return x

But I wonder: does the Lambda layer introduce any overhead during training or inference?

My second attempt is to create a Keras layer that wraps tf.nn.crelu:

class CRelu(Layer):
    def __init__(self, **kwargs):
        super(CRelu, self).__init__(**kwargs)

    def build(self, input_shape):
        super(CRelu, self).build(input_shape)

    def call(self, x):
        x = tf.nn.crelu(x, axis=-1)
        return x

    def compute_output_shape(self, input_shape):
        # CReLU doubles the size of the concatenation axis
        output_shape = list(input_shape)
        output_shape[-1] = output_shape[-1] * 2
        return tuple(output_shape)

def _conv_bn_crelu(x, n_filters, kernel_size):
    x = Conv2D(filters=n_filters, kernel_size=kernel_size, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = CRelu()(x)
    return x
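For context, CReLU simply concatenates the ReLU of the input and the ReLU of its negation along the channel axis, which is why the output channel count doubles. A minimal NumPy sketch of the same operation (the `crelu` helper below is illustrative, not part of either implementation above):

```python
import numpy as np

def crelu(x, axis=-1):
    # CReLU(x) = concat(relu(x), relu(-x)); the size of `axis` doubles
    return np.concatenate([np.maximum(x, 0), np.maximum(-x, 0)], axis=axis)

x = np.array([[1.0, -2.0]])
y = crelu(x)
print(y.shape)  # (1, 4)
print(y)        # [[1. 0. 0. 2.]]
```

This also shows why `compute_output_shape` must multiply the channel dimension by 2.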

Which version will be more efficient?

A pure Keras implementation would also be welcome, if possible.

Upvotes: 2

Views: 503

Answers (1)

Samuele Cornell

Reputation: 165

I don't think there is a significant speed difference between the two implementations.

The Lambda implementation is actually the simplest, but writing a custom Layer as you have done is usually better, especially for model saving and loading (via the get_config method).

In this case it matters little, since CReLU is trivial and has no weights to save and restore. You can, however, store the axis parameter, as in the code below; it will then be restored automatically when the model is loaded.

import tensorflow as tf
from keras.layers import Layer

class CRelu(Layer):
    def __init__(self, axis=-1, **kwargs):
        self.axis = axis
        super(CRelu, self).__init__(**kwargs)

    def build(self, input_shape):
        super(CRelu, self).build(input_shape)

    def call(self, x):
        x = tf.nn.crelu(x, axis=self.axis)
        return x

    def compute_output_shape(self, input_shape):
        # CReLU doubles the size of the concatenation axis
        output_shape = list(input_shape)
        output_shape[self.axis] = output_shape[self.axis] * 2
        return tuple(output_shape)

    def get_config(self):
        # store axis so the layer can be reconstructed on model load
        config = {'axis': self.axis}
        base_config = super(CRelu, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
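As for a "pure Keras" version: since CReLU is just a concatenation of two ReLUs, it can be expressed with backend ops alone, without tf.nn.crelu. A sketch, assuming the tf.keras API (with standalone Keras, import from keras.backend, keras.layers, and keras.models instead):

```python
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras.models import Model

def crelu(x, axis=-1):
    # concat(relu(x), relu(-x)) along `axis`; channel count doubles
    return K.concatenate([K.relu(x), K.relu(-x)], axis=axis)

inputs = Input(shape=(8, 8, 3))
outputs = Lambda(crelu)(inputs)
model = Model(inputs, outputs)
print(model.output_shape)  # (None, 8, 8, 6)
```

Because Keras infers the output shape by calling the function, no explicit compute_output_shape is needed with this Lambda-based variant.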

Upvotes: 1
