Ken Lin

Reputation: 23

Set only the bias to be non-trainable in TensorFlow Keras

When training neural networks for classification in TensorFlow/Keras, is it possible to make the bias term in the output layer non-trainable?

It looks like layer.trainable = False freezes both the kernel and the bias of the layer. Is it possible to freeze only the bias while still updating the kernel?
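A quick way to confirm the behaviour described above is to count a layer's trainable weights before and after setting trainable = False. This is a minimal sketch that assumes a Dense layer with 3 inputs and 4 units, chosen only for illustration:

```python
import tensorflow as tf

# A Dense layer with 3 inputs and 4 units has a kernel (3, 4) and a bias (4,)
layer = tf.keras.layers.Dense(4)
layer.build((None, 3))
n_trainable_before = len(layer.trainable_weights)  # kernel and bias

# Freezing the layer moves *both* variables to non_trainable_weights
layer.trainable = False
n_trainable_after = len(layer.trainable_weights)
n_frozen = len(layer.non_trainable_weights)

print(n_trainable_before, n_trainable_after, n_frozen)  # 2 0 2
```

The flag works at the layer level, so there is no per-variable switch here; that is why the answers below work around it.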

Upvotes: 0

Views: 3093

Answers (2)

Josaph

Reputation: 86

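A hacky solution for models that have not been built yet: set the bias with an initializer and pin it with a constraint. Keras applies a variable's constraint after each optimizer update, so the bias is reset to its fixed value at every step while the kernel trains normally.

If your model has already been built, you will need to replace the layers in order to add initializers and constraints. See https://github.com/keras-team/keras/issues/13100.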
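One way to "replace the layer" for an already-built model is to create a new layer with the constraint attached and copy the old weights into it. This is a minimal sketch assuming a single Dense layer; `freeze_dense_bias` is a hypothetical helper, not a Keras API, and the constraint class mirrors the one defined further down in this answer:

```python
import numpy as np
import tensorflow as tf


class ConstantTensorConstraint(tf.keras.constraints.Constraint):
  """Constrains a variable to the fixed tensor `t`."""

  def __init__(self, t):
    self.t = t

  def __call__(self, w):
    return self.t


def freeze_dense_bias(old):
  """Hypothetical helper: rebuild a built Dense layer with its bias pinned to its current value."""
  kernel, bias = old.get_weights()
  fixed_bias = tf.constant(bias)
  new = tf.keras.layers.Dense(
    old.units,
    activation=old.activation,
    bias_constraint=ConstantTensorConstraint(fixed_bias),
  )
  new.build((None, kernel.shape[0]))
  new.set_weights([kernel, bias])  # carry over the existing values
  return new


# Usage: an already-built layer whose bias we now want to freeze
old = tf.keras.layers.Dense(4, activation="relu", bias_initializer="random_normal")
old.build((None, 3))
new = freeze_dense_bias(old)
```

For a full model, the replacement layer would then be placed into a rebuilt model where the original one was.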

Biases with different values

import tensorflow as tf

class ConstantTensorInitializer(tf.keras.initializers.Initializer):
  """Initializes tensors to `t`."""

  def __init__(self, t):
    self.t = t

  def __call__(self, shape, dtype=None, **kwargs):
    # `t` must already have the bias shape, e.g. (filters,) for Conv2D
    return self.t

  def get_config(self):
    return {'t': self.t}

class ConstantTensorConstraint(tf.keras.constraints.Constraint):
  """Constrains tensors to `t`."""

  def __init__(self, t):
    self.t = t

  def __call__(self, w):
    # Ignore the updated value `w` and reset the variable to `t`
    return self.t

  def get_config(self):
    return {'t': self.t}


# Example: a Conv2D layer with 4 filters, each with its own fixed bias
biases = tf.constant([0.1, 0.2, 0.3, 0.4])
layer = tf.keras.layers.Conv2D(
  4,
  (3, 3),
  use_bias=True,
  bias_initializer=ConstantTensorInitializer(biases),
  bias_constraint=ConstantTensorConstraint(biases)
)
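To check that the constraint really freezes the bias while the kernel keeps learning, a short training run can compare both before and after. This sketch swaps the Conv2D for a Dense layer (an assumption made only to keep the random input data simple) and uses plain SGD on random data:

```python
import numpy as np
import tensorflow as tf


class ConstantTensorInitializer(tf.keras.initializers.Initializer):
  """Initializes tensors to `t`."""

  def __init__(self, t):
    self.t = t

  def __call__(self, shape, dtype=None, **kwargs):
    return self.t


class ConstantTensorConstraint(tf.keras.constraints.Constraint):
  """Constrains tensors to `t`."""

  def __init__(self, t):
    self.t = t

  def __call__(self, w):
    return self.t


biases = tf.constant([0.1, 0.2, 0.3, 0.4])
layer = tf.keras.layers.Dense(
  4,
  bias_initializer=ConstantTensorInitializer(biases),
  bias_constraint=ConstantTensorConstraint(biases),
)
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), layer])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1), loss="mse")

kernel_before = layer.kernel.numpy().copy()

# Random regression data, only to produce non-zero gradients
rng = np.random.default_rng(0)
x = rng.random((32, 3), dtype=np.float32)
y = rng.random((32, 4), dtype=np.float32)
model.fit(x, y, epochs=3, verbose=0)

kernel_changed = not np.allclose(layer.kernel.numpy(), kernel_before)
bias_fixed = np.allclose(layer.bias.numpy(), biases.numpy())
print(kernel_changed, bias_fixed)
```

Because the constraint runs after each update, both flags should come out True: the bias still receives gradients, but its value is overwritten every step.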

Biases with the same value

class ConstantValueConstraint(tf.keras.constraints.Constraint):
  """Constrains the elements of the tensor to `value`."""

  def __init__(self, value):
    self.value = value

  def __call__(self, w):
    # Keep the shape of `w` but replace every element with `value`
    return w * 0 + self.value

  def get_config(self):
    return {'value': self.value}


# Example: every bias element is initialized to, and kept at, 0.1
layer = tf.keras.layers.Conv2D(
  4,
  (3, 3),
  use_bias=True,
  bias_initializer=tf.keras.initializers.Constant(0.1),
  bias_constraint=ConstantValueConstraint(0.1)
)
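The same-value constraint can be checked in two ways: by applying it directly to a tensor, and by training a small model. As above, this sketch uses a Dense layer and random data purely as an illustrative assumption:

```python
import numpy as np
import tensorflow as tf


class ConstantValueConstraint(tf.keras.constraints.Constraint):
  """Constrains the elements of the tensor to `value`."""

  def __init__(self, value):
    self.value = value

  def __call__(self, w):
    return w * 0 + self.value


# Applied directly: every element becomes 0.1 and the shape is preserved
pinned = ConstantValueConstraint(0.1)(tf.constant([[1.0, -2.0], [3.0, 4.0]]))

# Applied during training of a small Dense layer
layer = tf.keras.layers.Dense(
  4,
  bias_initializer=tf.keras.initializers.Constant(0.1),
  bias_constraint=ConstantValueConstraint(0.1),
)
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), layer])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1), loss="mse")

rng = np.random.default_rng(0)
x = rng.random((32, 3), dtype=np.float32)
y = rng.random((32, 4), dtype=np.float32)
model.fit(x, y, epochs=3, verbose=0)
```

After training, `layer.bias` should still hold 0.1 in every position.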

Upvotes: 3

CrazyBrazilian

Reputation: 1070

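You can set use_bias=False on any layer. This removes the bias entirely, which is equivalent to a bias frozen at zero; it does not let you keep a fixed non-zero bias.

For example: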

model.add(layers.Conv2D(64, (3, 3), use_bias=False))
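A quick check that use_bias=False leaves only the kernel. The 28x28 single-channel input shape is an assumption chosen for illustration:

```python
import tensorflow as tf

layer = tf.keras.layers.Conv2D(64, (3, 3), use_bias=False)
layer.build((None, 28, 28, 1))  # e.g. 28x28 grayscale images

n_weights = len(layer.trainable_weights)  # only the kernel is created
kernel_shape = tuple(layer.kernel.shape)  # (height, width, in_channels, filters)
print(n_weights, kernel_shape)  # 1 (3, 3, 1, 64)
```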

Upvotes: 0
