I want to modify the following Keras mean squared error (MSE) loss so that the loss is only computed sparsely:
def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)
My output y is a 3-channel image, where the 3rd channel is non-zero only at those pixels where the loss is to be computed. Any idea how I can modify the above to compute a sparse loss?
This is not the exact loss you are looking for, but I hope it will give you a hint for writing your function (see also here for a GitHub discussion):
from keras import backend as K

def masked_mse(mask_value):
    def f(y_true, y_pred):
        # 1.0 where the target is a real value, 0.0 where it equals mask_value
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        # average over the non-masked entries only
        masked_mse = (K.sum(masked_squared_error, axis=-1) /
                      K.sum(mask_true, axis=-1))
        return masked_mse
    f.__name__ = 'Masked MSE (mask_value={})'.format(mask_value)
    return f
The function computes the MSE loss over all the values of the predicted output, except for those elements whose corresponding value in the true output is equal to a masking value (e.g. -1).
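For a single sample, the function above computes the following, where v is the mask value and the sums run over the last axis (axis=-1). Because the mask m_i is either 0 or 1, squaring m_i (y_i - ŷ_i) gives the same result as multiplying the squared error by m_i:

```latex
m_i = \begin{cases} 1 & \text{if } y_i \neq v \\ 0 & \text{otherwise} \end{cases}
\qquad
\mathrm{MSE}_{\text{masked}} = \frac{\sum_i m_i \,(y_i - \hat{y}_i)^2}{\sum_i m_i}
```

If every entry of a sample is masked, both sums are 0 and the result is 0/0, i.e. nan.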
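Two notes:
- When computing the mean, the denominator must be the number of non-masked values, not the full length of the last axis; that's why I'm not using K.mean(masked_squared_error, axis=-1) and I'm instead averaging manually.
- The masking value must be a regular, finite number (np.nan or np.inf will not do the job), which means that you'll have to adapt your data so that it does not contain the mask_value.
In this example, the target output is always [1, 1, 1, 1], except that progressively more of its entries are replaced by the mask value -1, so the corresponding prediction values are ignored.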
y_pred = K.constant([[ 1, 1, 1, 1],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3],
                     [ 1, 1, 1, 3]])
y_true = K.constant([[ 1, 1, 1, 1],
                     [ 1, 1, 1, 1],
                     [-1, 1, 1, 1],
                     [-1,-1, 1, 1],
                     [-1,-1,-1, 1],
                     [-1,-1,-1,-1]])
true = K.eval(y_true)
pred = K.eval(y_pred)
loss = K.eval(masked_mse(-1)(y_true, y_pred))
for i in range(true.shape[0]):
    print(true[i], pred[i], loss[i], sep='\t')
The expected output is:
[ 1. 1. 1. 1.] [ 1. 1. 1. 1.] 0.0
[ 1. 1. 1. 1.] [ 1. 1. 1. 3.] 1.0
[-1. 1. 1. 1.] [ 1. 1. 1. 3.] 1.33333
[-1. -1. 1. 1.] [ 1. 1. 1. 3.] 2.0
[-1. -1. -1. 1.] [ 1. 1. 1. 3.] 4.0
[-1. -1. -1. -1.] [ 1. 1. 1. 3.] nan
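To check these numbers without a Keras session, the same computation can be mirrored in NumPy. masked_mse_np below is a hypothetical helper written only for this check, not part of Keras. It reproduces the table above (including the nan from the 0/0 in the last row) and also shows why np.nan and np.inf don't work as mask values: nan != nan is always true, so nothing gets masked, while an inf target does get masked but 0 * inf is itself nan.

```python
import numpy as np

def masked_mse_np(y_true, y_pred, mask_value):
    """NumPy mirror of the Keras masked_mse above (hypothetical helper, for checking only)."""
    with np.errstate(invalid="ignore"):  # silence warnings from 0/0 and 0 * inf
        mask = (y_true != mask_value).astype(float)     # 1.0 where the target is kept
        sq_err = np.square(mask * (y_true - y_pred))    # masked entries become 0 (unless inf/nan)
        return sq_err.sum(axis=-1) / mask.sum(axis=-1)  # average over kept entries only


y_pred = np.array([[1, 1, 1, 1]] + [[1, 1, 1, 3]] * 5, dtype=float)
y_true = np.array([[ 1,  1,  1,  1],
                   [ 1,  1,  1,  1],
                   [-1,  1,  1,  1],
                   [-1, -1,  1,  1],
                   [-1, -1, -1,  1],
                   [-1, -1, -1, -1]], dtype=float)
print(masked_mse_np(y_true, y_pred, -1))  # 0, 1, 1.333..., 2, 4, nan as in the table

# Mask value nan: nan != nan is True, so nothing is masked and the loss is nan.
p = np.array([[1.0, 1.0]])
print(masked_mse_np(np.array([[np.nan, 1.0]]), p, np.nan))
# Mask value inf: the entry is masked, but 0 * (inf - 1) is still nan.
print(masked_mse_np(np.array([[np.inf, 1.0]]), p, np.inf))
```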
To prevent nan from showing up when every value in a sample is masked, follow the instructions here. The following assumes you want the masked value (background) to be equal to zero:
from keras import backend as K

# Copied almost character-by-character (only change is default mask_value=0)
# from https://github.com/keras-team/keras/issues/7065#issuecomment-394401137
def masked_mse(mask_value=0):
    """
    Made default mask_value=0; not sure this is necessary/helpful
    """
    def f(y_true, y_pred):
        mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        masked_squared_error = K.square(mask_true * (y_true - y_pred))
        # in case mask_true is 0 everywhere, the error would be nan, therefore divide by at least 1
        # this doesn't change anything as where sum(mask_true)==0, sum(masked_squared_error)==0 as well
        masked_mse = K.sum(masked_squared_error, axis=-1) / K.maximum(K.sum(mask_true, axis=-1), 1)
        return masked_mse
    f.__name__ = str('Masked MSE (mask_value={})'.format(mask_value))
    return f
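Neither function above uses the layout from the question, where the mask lives in a separate channel of the target rather than being a special value. Below is a minimal sketch of that variant, assuming channels-last arrays of shape (batch, H, W, 3), that channels 0 and 1 hold the values to compare, and that channel 2 of y_true is non-zero exactly where the loss should count (all three are assumptions about the data). It is written in NumPy so the arithmetic can be checked directly, and channel_masked_mse_np is a hypothetical name. To use it as a Keras loss, the same expressions should translate to K.cast(K.not_equal(...), K.floatx()), K.square, K.sum and K.maximum, as in the functions above.

```python
import numpy as np

def channel_masked_mse_np(y_true, y_pred):
    """Per-sample MSE over channels 0-1, counted only where y_true[..., 2] != 0.

    Assumes channels-last arrays of shape (batch, H, W, 3); the channel roles
    are an assumption about the data, not taken from the question.
    """
    # 1.0 at pixels where the loss should count, plus a trailing axis for broadcasting
    mask = (y_true[..., 2] != 0).astype(float)[..., np.newaxis]   # (batch, H, W, 1)
    # squared error on the two value channels; masked-out pixels contribute 0
    sq_err = mask * np.square(y_true[..., :2] - y_pred[..., :2])  # (batch, H, W, 2)
    # number of contributing terms per sample: kept pixels times 2 channels
    n_terms = 2 * mask.sum(axis=(1, 2, 3))
    # divide by at least 1 so a sample with no kept pixels gives 0 rather than nan
    return sq_err.sum(axis=(1, 2, 3)) / np.maximum(n_terms, 1)


y_true = np.zeros((2, 2, 2, 3))
y_pred = np.zeros((2, 2, 2, 3))
y_true[0, 0, 0] = [1.0, 2.0, 1.0]  # sample 0: a single kept pixel
y_pred[0, 0, 0] = [0.0, 0.0, 5.0]  # predicted channel 2 is ignored
y_pred[1] = 9.0                    # sample 1: no kept pixels at all
print(channel_masked_mse_np(y_true, y_pred))  # per-sample losses 2.5 and 0.0
```

Sample 0 has one kept pixel with errors of 1 and 2 on the two compared channels, giving (1 + 4) / 2 = 2.5; sample 1 has no kept pixels, so the maximum(..., 1) guard returns 0 instead of nan. Normalizing per sample means an image with few labelled pixels weighs as much as one with many; dividing by a batch-wide count instead would be the alternative.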