Nick Merrill

Reputation: 114

Custom Keras loss function that conditionally creates a zero gradient

My problem is that I don't want the weights to be adjusted when y_true takes certain values. I do not want to simply remove those examples from the training data because of the nature of the RNN I am trying to use.

Is there a way to write a conditional loss function in Keras with this behavior?

For example: if y_true is negative, apply a zero gradient so that the parameters in the model do not change; if y_true is positive, use loss = losses.mean_squared_error(y_true, y_pred).

Upvotes: 1

Views: 2270

Answers (2)


Reputation: 33410

You can define a custom loss function and use K.switch to conditionally zero out the loss:

from keras import backend as K
from keras import losses

def custom_loss(y_true, y_pred):
    # per-sample MSE, reduced over the last axis: shape (batch_size,)
    loss = losses.mean_squared_error(y_true, y_pred)
    # where y_true == 0, replace the loss with zero so no gradient flows
    return K.switch(K.flatten(K.equal(y_true, 0.)), K.zeros_like(loss), loss)

Test:

import numpy as np

from keras import models
from keras import layers

model = models.Sequential()
model.add(layers.Dense(1, input_shape=(1,)))

model.compile(loss=custom_loss, optimizer='adam')

# snapshot the initial parameters
weights, bias = model.layers[0].get_weights()

x = np.array([1, 2, 3])
y = np.array([0, 0, 0])  # all targets are zero, so the loss is zero everywhere

model.train_on_batch(x, y)

# check that the parameters have not changed after training on the batch
>>> (weights == model.layers[0].get_weights()[0]).all()
True

>>> (bias == model.layers[0].get_weights()[1]).all()
True
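
The question actually asks about negative y_true rather than zero; assuming that is the condition you want, the same pattern works with K.less in place of K.equal (a sketch along the same lines, not tested on an RNN):

def custom_loss_negative(y_true, y_pred):
    loss = losses.mean_squared_error(y_true, y_pred)
    # zero out the loss wherever y_true < 0, so those samples
    # contribute no gradient and the weights stay unchanged
    return K.switch(K.flatten(K.less(y_true, 0.)), K.zeros_like(loss), loss)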

Upvotes: 2

Gerges

Reputation: 6509

Since the y's come in batches, you need to select the non-zero entries from the batch inside the custom loss function:

import tensorflow as tf
from keras import losses

def myloss(y_true, y_pred):
    # keep only the samples whose target is non-zero
    idx = tf.not_equal(y_true, 0)
    y_true = tf.boolean_mask(y_true, idx)
    y_pred = tf.boolean_mask(y_pred, idx)
    return losses.mean_squared_error(y_true, y_pred)

Then it can be used as follows:

import numpy as np
import keras
from keras.layers import Dense

model = keras.Sequential([Dense(32, input_shape=(2,)), Dense(1)])
model.compile('adam', loss=myloss)

x = np.random.randn(2, 2)
y = np.array([1, 0])  # the zero-target sample is masked out of the loss
model.fit(x, y)

But you might need extra logic in case all the y_true in a batch are zero: the masked tensors would then be empty, and taking the mean of an empty tensor yields NaN. In that case, the loss function can be modified as follows:

def myloss2(y_true, y_pred):
    idx = tf.not_equal(y_true, 0)
    y_true = tf.boolean_mask(y_true, idx)
    y_pred = tf.boolean_mask(y_pred, idx)
    # if every target in the batch was zero, the masked tensors are empty;
    # return a constant 0 loss instead of averaging an empty tensor
    loss = tf.cond(tf.equal(tf.shape(y_pred)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32),
                   lambda: losses.mean_squared_error(y_true, y_pred))
    return loss
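
For example, a batch where every y_true is zero now yields the constant zero loss instead of a NaN (a quick sanity check, reusing the model defined above):

model.compile('adam', loss=myloss2)

x = np.random.randn(2, 2)
y = np.array([0, 0])  # the whole batch is masked out
model.fit(x, y)       # reported loss is the constant 0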

Upvotes: 1
