vampiretap

Reputation: 361

How to implement a weighted mean squared error function in Keras

I am defining a weighted mean squared error in Keras as follows:

import numpy as np
import tensorflow as tf
from keras import backend as K

def weighted_mse(yTrue, yPred):
    # w0..w3 are the per-output weights, defined elsewhere
    data_weights = [w0, w1, w2, w3]
    data_weights_np = np.asarray(data_weights, np.float32)
    weights = tf.convert_to_tensor(data_weights_np, np.float32)
    return K.mean(weights * K.square(yTrue - yPred))

I have a list of weights, one per output. The predictions have shape (25, 4), for example: a batch of 25 samples produced by a final dense layer with 4 units. I want to weight these predictions in the mean squared error, so I build a tensor from the weights and multiply it by the squared error before taking the mean. Is this the correct way to do so? When I print the result of tf.shape for yTrue and yPred, it shows:

Tensor("loss_19/dense_20_loss/Shape:0", shape=(3,), dtype=int32)

and for weights:

Tensor("loss_19/dense_20_loss/Shape_2:0", shape=(1,), dtype=int32)
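For reference, the broadcasting I am relying on can be checked with plain NumPy (the shapes match the ones above; the weight values here are placeholders for w0..w3):

```python
import numpy as np

# Shapes from the question: predictions of shape (25, 4).
y_true = np.random.rand(25, 4).astype(np.float32)
y_pred = np.random.rand(25, 4).astype(np.float32)

# Placeholder values standing in for w0..w3.
weights = np.array([1.0, 2.0, 0.5, 1.5], dtype=np.float32)

# Broadcasting aligns the (4,) weights with the last axis of (25, 4),
# so each output column is scaled by its own weight.
weighted_sq_err = weights * np.square(y_true - y_pred)
print(weighted_sq_err.shape)  # (25, 4)
loss = np.mean(weighted_sq_err)
```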

Upvotes: 4

Views: 3716

Answers (1)

nuric

Reputation: 11225

The Keras API already provides a mechanism for weighting, via arguments to the model.fit function. From the documentation:

class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.

sample_weight: Optional Numpy array of weights for the training samples, used for weighting the loss function (during training only). You can either pass a flat (1D) Numpy array with the same length as the input samples (1:1 mapping between weights and samples), or in the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile().

If you have a weight for each sample, you can pass the NumPy array as sample_weight to achieve the same effect without writing your own loss function.
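As a minimal sketch of the arithmetic this corresponds to (assuming the default reduction, which averages each sample's loss scaled by its weight; the array values here are hypothetical):

```python
import numpy as np

# Hypothetical batch: 25 samples, 4 outputs each.
y_true = np.ones((25, 4), dtype=np.float32)
y_pred = np.zeros((25, 4), dtype=np.float32)

# One weight per sample: the 1:1 mapping the docs describe.
sample_weight = np.linspace(0.5, 1.5, 25).astype(np.float32)

# Plain MSE per sample: mean over the 4 outputs.
per_sample_mse = np.mean(np.square(y_true - y_pred), axis=-1)  # shape (25,)

# Each sample's loss is scaled by its weight before averaging.
weighted_loss = np.mean(per_sample_mse * sample_weight)
print(weighted_loss)  # 1.0 here, since every per-sample MSE is 1
```

Passing such an array as `model.fit(x, y, sample_weight=...)` weights the loss per sample without a custom loss function.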

Upvotes: 4
