Reputation: 21622
I'm trying to implement a CReLU layer in Keras.
One option that seems to work is to use a Lambda layer:
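import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Conv2D, Lambda

def _crelu(x):
    # tf.nn.crelu concatenates x and -x along the axis, then applies ReLU.
    x = tf.nn.crelu(x, axis=-1)
    return x

def _conv_bn_crelu(x, n_filters, kernel_size):
    x = Conv2D(filters=n_filters, kernel_size=kernel_size, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Lambda(_crelu)(x)
    return x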
But I wonder whether the Lambda layer introduces some overhead during training or inference.
My second attempt is to create a Keras layer that wraps tf.nn.crelu:
from tensorflow.keras.layers import Layer

class CRelu(Layer):
    def __init__(self, **kwargs):
        super(CRelu, self).__init__(**kwargs)

    def build(self, input_shape):
        super(CRelu, self).build(input_shape)

    def call(self, x):
        x = tf.nn.crelu(x, axis=-1)
        return x

    def compute_output_shape(self, input_shape):
        # CReLU doubles the size of the last (channel) axis.
        output_shape = list(input_shape)
        output_shape[-1] = output_shape[-1] * 2
        output_shape = tuple(output_shape)
        return output_shape
def _conv_bn_crelu(x, n_filters, kernel_size):
    x = Conv2D(filters=n_filters, kernel_size=kernel_size, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = CRelu()(x)
    return x
Which version is more efficient?
I'd also welcome a pure Keras implementation, if that's possible.
Upvotes: 2
Views: 503
Reputation: 165
I don't think there is a significant difference between the two implementations speed-wise.
The Lambda implementation is actually the simplest, but writing a custom Layer as you have done is usually better, especially with regard to model saving and loading (the get_config method).
In this case it doesn't matter much, as CReLU is trivial and doesn't require saving and restoring parameters. You can, however, store the axis parameter as in the code below, so that it is retrieved automatically when the model is loaded.
class CRelu(Layer):
    def __init__(self, axis=-1, **kwargs):
        super(CRelu, self).__init__(**kwargs)
        self.axis = axis

    def build(self, input_shape):
        super(CRelu, self).build(input_shape)

    def call(self, x):
        x = tf.nn.crelu(x, axis=self.axis)
        return x

    def compute_output_shape(self, input_shape):
        # CReLU doubles the size of the axis it concatenates along.
        output_shape = list(input_shape)
        output_shape[self.axis] = output_shape[self.axis] * 2
        output_shape = tuple(output_shape)
        return output_shape

    def get_config(self):
        # get_config takes no extra arguments; merge 'axis' into the base config.
        config = {'axis': self.axis}
        base_config = super(CRelu, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
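To check that the axis really survives a save/load cycle, here is a minimal sketch of a config round trip. It assumes TensorFlow 2.x with eager execution and the tensorflow.keras import path, which the question does not specify. The layer is a trimmed-down copy of the one above, and the names layer, restored, x and y are only for illustration.

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class CRelu(Layer):
    """Thin wrapper around tf.nn.crelu that remembers its axis."""

    def __init__(self, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, x):
        return tf.nn.crelu(x, axis=self.axis)

    def get_config(self):
        # Start from the base config (name, dtype, ...) and add our own field.
        config = super().get_config()
        config.update({"axis": self.axis})
        return config


# Build a layer with a non-default axis and rebuild it from its config,
# which is what Keras does internally when a saved model is loaded.
layer = CRelu(axis=1)
restored = CRelu.from_config(layer.get_config())

x = tf.constant([[-1.0, 2.0]])
y = restored(x).numpy()
# CReLU is relu(concat([x, -x])), so y is [[0, 2, 1, 0]]: the axis has doubled.
print(restored.axis, y)
```

When loading a full model that contains this layer, you would normally also pass custom_objects={'CRelu': CRelu} to load_model so Keras can find the class.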
Upvotes: 1