user79983

Reputation: 75

How to apply kernel regularization in a custom layer in Keras/TensorFlow?

Consider the following custom layer code from a TensorFlow tutorial:

import tensorflow as tf

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]),
                                         self.num_outputs])

  def call(self, input):
    return tf.matmul(input, self.kernel)

How do I apply any pre-defined regularization (say tf.keras.regularizers.L1) or custom regularization on the parameters of the custom layer?

Upvotes: 6

Views: 3838

Answers (1)

today

Reputation: 33420

The add_weight method takes a regularizer argument that you can use to apply regularization on that weight. For example:

self.kernel = self.add_weight("kernel",
                               shape=[int(input_shape[-1]), self.num_outputs],
                               regularizer=tf.keras.regularizers.l1_l2())

Alternatively, to have more control as in the other built-in layers, you can modify the definition of the custom layer and add a kernel_regularizer argument to the __init__ method:

from tensorflow.keras import regularizers

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs, kernel_regularizer=None):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs
    # regularizers.get resolves strings, config dicts, or regularizer
    # instances into a Regularizer object (or None).
    self.kernel_regularizer = regularizers.get(kernel_regularizer)

  def build(self, input_shape):
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]), self.num_outputs],
                                  regularizer=self.kernel_regularizer)

  def call(self, input):
    return tf.matmul(input, self.kernel)

With that, you can even pass a string like 'l1' or 'l2' as the kernel_regularizer argument when constructing the layer, and it will be resolved properly by regularizers.get.
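For example, a minimal sketch of constructing such a layer with a string regularizer and checking that the penalty shows up in layer.losses (the layer and argument names follow the answer above; the input shape is arbitrary):

```python
import tensorflow as tf
from tensorflow.keras import regularizers

class MyDenseLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs, kernel_regularizer=None):
        super().__init__()
        self.num_outputs = num_outputs
        # Resolve a string, config dict, or Regularizer instance (or None).
        self.kernel_regularizer = regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=[int(input_shape[-1]), self.num_outputs],
            regularizer=self.kernel_regularizer)

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

# Passing the string 'l2' is resolved to tf.keras.regularizers.L2.
layer = MyDenseLayer(4, kernel_regularizer='l2')
out = layer(tf.ones([2, 3]))   # build() runs on first call

# The regularization penalty is collected in layer.losses and is added
# to the total loss automatically when the layer is used inside a Model.
print(layer.losses)
```

When the layer is used in a model compiled with model.compile, Keras adds the tensors in losses to the training objective, so no extra wiring is needed.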

Upvotes: 13
