delhi_loafer

Reputation: 129

How to implement sparse mean squared error loss in Keras

I want to modify the following Keras mean squared error (MSE) loss so that the loss is only computed sparsely.

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)

My output y is a 3-channel image, where the 3rd channel is non-zero only at those pixels where the loss is to be computed. Any idea how I can modify the above to compute the loss sparsely?

Upvotes: 5

Views: 2112

Answers (2)

baldassarreFe

Reputation: 163

This is not the exact loss you are looking for, but I hope it will give you a hint for writing your own function (see also here for a GitHub discussion):

def masked_mse(mask_value):
    def f(y_true, y_pred):
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        masked_mse = (K.sum(masked_squared_error, axis=-1) /
                      K.sum(mask_true, axis=-1))
        return masked_mse
    f.__name__ = 'Masked MSE (mask_value={})'.format(mask_value)
    return f

The function computes the MSE loss over all the values of the predicted output, except for those elements whose corresponding value in the true output is equal to a masking value (e.g. -1).

Two notes:

  • when computing the mean, the denominator must be the count of non-masked values, not the dimension of the array; that is why I am not using K.mean(masked_squared_error, axis=-1) and am instead averaging manually.
  • the masking value must be a valid number (e.g. np.nan or np.inf will not do the job), which means that you will have to adapt your data so that it never contains the mask_value.

In this example, the target output is always [1, 1, 1, 1], but progressively more target values are masked with -1.

y_pred = K.constant([[ 1, 1, 1, 1], 
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3]])
y_true = K.constant([[ 1, 1, 1, 1],
                     [ 1, 1, 1, 1],
                     [-1, 1, 1, 1],
                     [-1,-1, 1, 1],
                     [-1,-1,-1, 1],
                     [-1,-1,-1,-1]])

true = K.eval(y_true)
pred = K.eval(y_pred)
loss = K.eval(masked_mse(-1)(y_true, y_pred))

for i in range(true.shape[0]):
    print(true[i], pred[i], loss[i], sep='\t')

The expected output is:

[ 1.  1.  1.  1.]  [ 1.  1.  1.  1.]  0.0
[ 1.  1.  1.  1.]  [ 1.  1.  1.  3.]  1.0
[-1.  1.  1.  1.]  [ 1.  1.  1.  3.]  1.33333
[-1. -1.  1.  1.]  [ 1.  1.  1.  3.]  2.0
[-1. -1. -1.  1.]  [ 1.  1.  1.  3.]  4.0
[-1. -1. -1. -1.]  [ 1.  1.  1.  3.]  nan
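To match the question's setup, where the mask lives in the image's third channel rather than in a sentinel value inside the targets, the same idea can be adapted along these lines. This is only a sketch, not tested against the asker's data: the name channel_masked_mse, the two-content-channel layout, and the tensorflow.keras import are my assumptions.

```python
from tensorflow.keras import backend as K

def channel_masked_mse(y_true, y_pred):
    """Sketch: MSE over the first two channels, counted only at pixels
    where the third channel of y_true is non-zero (assumed layout)."""
    mask = K.cast(K.not_equal(y_true[..., -1], 0), K.floatx())   # (batch, H, W)
    squared_error = K.square(y_true[..., :2] - y_pred[..., :2])  # (batch, H, W, 2)
    masked_error = squared_error * K.expand_dims(mask, axis=-1)
    # Divide by the number of unmasked values (2 channels per masked pixel),
    # guarding against an all-zero mask.
    return K.sum(masked_error) / K.maximum(2.0 * K.sum(mask), 1.0)

# Tiny 2x2 "image": the loss is counted only at the two pixels whose
# third channel is 1.
y_true = K.constant([[[[1., 2., 1.], [3., 4., 0.]],
                      [[5., 6., 1.], [7., 8., 0.]]]])
y_pred = K.constant([[[[2., 2., 0.], [0., 0., 0.]],
                      [[5., 7., 0.], [9., 9., 0.]]]])
print(K.eval(channel_masked_mse(y_true, y_pred)))  # 0.5
```

The two masked pixels contribute squared errors of 1 each over 4 unmasked values, hence 0.5.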

Upvotes: 8

tdMJN6B2JtUe

Reputation: 428

To prevent nan from showing up, follow the instructions here. The following assumes you want the masked value (background) to be equal to zero:

# Copied almost character-by-character (the only change is the default mask_value=0)
# from https://github.com/keras-team/keras/issues/7065#issuecomment-394401137
def masked_mse(mask_value=0):
    """
    Made default mask_value=0; not sure this is necessary/helpful
    """
    def f(y_true, y_pred):
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        # in case mask_true is 0 everywhere, the error would be nan, therefore divide by at least 1
        # this doesn't change anything as where sum(mask_true)==0, sum(masked_squared_error)==0 as well
        masked_mse = K.sum(masked_squared_error, axis=-1) / K.maximum(K.sum(mask_true, axis=-1), 1)
        return masked_mse
    f.__name__ = str('Masked MSE (mask_value={})'.format(mask_value))
    return f
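As a quick sanity check, an all-masked row now yields 0.0 where the unguarded division produced nan. This snippet is mine, not the answerer's: it redefines the function so it runs on its own, and the tensorflow.keras import and test inputs are assumptions.

```python
from tensorflow.keras import backend as K

def masked_mse(mask_value=0):
    def f(y_true, y_pred):
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        # Divide by at least 1 so an all-masked row yields 0 instead of nan.
        return (K.sum(masked_squared_error, axis=-1) /
                K.maximum(K.sum(mask_true, axis=-1), 1))
    return f

y_true = K.constant([[0., 0., 0., 0.],   # fully masked row
                     [1., 1., 1., 1.]])
y_pred = K.constant([[1., 1., 1., 3.],
                     [1., 1., 1., 3.]])
print(K.eval(masked_mse(0)(y_true, y_pred)))  # [0. 1.]
```

The first row is fully masked, so both the numerator and the mask count are 0 and the guarded denominator of 1 gives 0.0; the second row averages a single squared error of 4 over 4 values, giving 1.0.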

Upvotes: 1
