VincFort

Reputation: 1180

TensorFlow masked loss function: wrong inputs

I want to implement a custom loss function in which the loss is 0 for every element where y_true == -np.inf. For the elements where y_true != -np.inf, the mean absolute error is calculated.
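One way to write this spec down for a sample with n target elements is shown below. It rests on an assumption: it divides by all n elements, so masked entries still count toward the mean. Dividing by the number of non-masked elements is an equally plausible reading. The masked terms are dropped from the sum rather than multiplied by a 0/1 mask, because in floating point 0 · ∞ evaluates to NaN.

```latex
\mathcal{L}(y, \hat{y}) = \frac{1}{n} \sum_{i \,:\, y_i \neq -\infty} \lvert y_i - \hat{y}_i \rvert
```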

Here is what I have:

def mae_masked(y_true, y_pred):
    y_true_masked = tf.cast(tf.boolean_mask(y_true, tf.math.is_inf(-y_true)),dtype=y_true.dtype)
    y_pred_masked = tf.cast(tf.boolean_mask(y_pred, tf.math.is_inf(-y_true)),dtype=y_pred.dtype)
    y_true_masked *= y_true
    y_pred_masked *= y_pred
    return tf.keras.losses.MeanAbsoluteError(y_true_masked, y_pred_masked)
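The following small sketch shows what the masking in this code actually selects, using a made-up 2-by-2 target tensor. tf.boolean_mask keeps the positions where the mask is True and returns a flat 1-D tensor. If you try to mask by multiplication instead, the -inf entry turns into NaN.

```python
import numpy as np
import tensorflow as tf

# Made-up targets with one -inf entry.
y_true = tf.constant([[1.0, -np.inf], [3.0, 4.0]])

# is_inf(-y_true) is True wherever y_true is +inf or -inf.
inf_mask = tf.math.is_inf(-y_true)

# boolean_mask keeps the positions where the mask is True and flattens the
# result to 1-D, so this selects the -inf entry instead of dropping it.
selected = tf.boolean_mask(y_true, inf_mask)

# Inverting the mask keeps the finite targets, still as a flat 1-D tensor.
finite = tf.boolean_mask(y_true, tf.logical_not(inf_mask))

# Masking by multiplication does not work either: 0 * -inf is NaN.
multiplied = tf.cast(tf.logical_not(inf_mask), y_true.dtype) * y_true

print(selected.numpy())    # [-inf]
print(finite.numpy())      # [1. 3. 4.]
print(multiplied.numpy())  # the former -inf position is now nan
```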

When fitting the model, I get the following error regarding the MeanAbsoluteError function:

TypeError: Expected float32 passed to parameter 'y' of op 'Equal', got 'auto' of type 'str' instead. Error: Expected float32, got 'auto' of type 'str' instead.

I don't understand this error, since both y_true and y_pred are of type float32.

When I print the inputs to MeanAbsoluteError, I get the following, which shows that both inputs are float32:

Tensor("mae_masked/mul_1:0", shape=(None, 4), dtype=float32)
Tensor("mae_masked/mul:0", shape=(None, 4), dtype=float32)

Upvotes: 0

Views: 96

Answers (1)

Nicolas Gervais

Reputation: 36624

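I'm not sure why you're casting: tf.boolean_mask already returns the input's dtype, so the casts do nothing. Your mask is also inverted. tf.boolean_mask keeps the elements where the mask is True, which here means the infinite targets, and it flattens the result to 1-D. Most importantly, you're not using tf.keras.losses.MeanAbsoluteError correctly. You're passing y_true and y_pred while instantiating the object. The constructor's first positional parameter is reduction, so Keras compares your tensor against the reduction string 'auto', and that comparison is the failing Equal op in your error. Use the function tf.keras.losses.mean_absolute_error(y_true, y_pred) instead, or create the object first and then call it: tf.keras.losses.MeanAbsoluteError()(y_true, y_pred).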
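Here is a minimal sketch of the two correct call patterns, using made-up tensors. The function returns one loss value per sample. The object applies its reduction (by default, averaging over the batch) and returns a scalar.

```python
import numpy as np
import tensorflow as tf

# Made-up batch of 2 samples with 2 outputs each.
y_true = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y_pred = tf.constant([[2.0, 2.0], [1.0, 4.0]])

# Functional form: one value per sample (mean over the last axis).
per_sample = tf.keras.losses.mean_absolute_error(y_true, y_pred)

# Class form: construct the loss object first (its constructor takes options
# such as reduction and name, not tensors), then call it on the tensors.
mae = tf.keras.losses.MeanAbsoluteError()
batch_loss = mae(y_true, y_pred)  # reduced over the batch to a scalar

print(per_sample.numpy())  # [0.5 1. ]
print(float(batch_loss))   # 0.75
```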

Either way, this would work and seems simpler:

import tensorflow as tf
import numpy as np


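def mae_masked(y_true, y_pred):
    # Mask by y_true only: wherever the target is -inf, set both y_true and
    # y_pred to 0, so that element contributes |0 - 0| = 0 to the loss.
    # tf.zeros_like matches the dtype of the other tf.where branch.
    mask = tf.math.equal(y_true, -np.inf)
    y_true_masked = tf.where(mask, tf.zeros_like(y_true), y_true)
    y_pred_masked = tf.where(mask, tf.zeros_like(y_pred), y_pred)
    return tf.keras.losses.mean_absolute_error(y_true_masked, y_pred_masked)


mae_masked(tf.convert_to_tensor([[1.], [2.], [3.], [-np.inf]]),
           tf.convert_to_tensor([[4.], [7.], [5.], [2.]]))  # should be 3, 5, 2, 0
<tf.Tensor: shape=(4,), dtype=float32, numpy=array([3., 5., 2., 0.], dtype=float32)>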
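Setting masked positions to 0 and then calling mean_absolute_error still divides by every element in the last axis, so masked entries pull the per-sample mean toward 0. If the mean should run over the non-masked elements only, one possible variant is sketched below. The name mae_masked_mean is hypothetical. tf.math.divide_no_nan makes a row whose targets are all -inf return 0 instead of NaN.

```python
import numpy as np
import tensorflow as tf


def mae_masked_mean(y_true, y_pred):
    """Hypothetical variant: MAE averaged over the non-masked elements only.

    Elements where y_true == -inf contribute neither to the sum nor to the
    count, so they do not pull the per-sample mean toward 0.
    """
    keep = tf.logical_not(tf.math.equal(y_true, -np.inf))
    # Select with tf.where rather than multiplying by the mask: 0 * inf is NaN.
    abs_err = tf.where(keep, tf.abs(y_true - y_pred), tf.zeros_like(y_pred))
    total = tf.reduce_sum(abs_err, axis=-1)
    count = tf.reduce_sum(tf.cast(keep, y_pred.dtype), axis=-1)
    # divide_no_nan returns 0 for rows in which every target is masked.
    return tf.math.divide_no_nan(total, count)


# Made-up example: the last sample has only masked targets.
y_true = tf.constant([[1.0, -np.inf], [3.0, 4.0], [-np.inf, -np.inf]])
y_pred = tf.constant([[2.0, 9.0], [1.0, 5.0], [5.0, 5.0]])
loss = mae_masked_mean(y_true, y_pred)
print(loss.numpy())  # [1.  1.5 0. ]
```

Like any other loss function, it can be passed to model.compile(loss=mae_masked_mean).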

Upvotes: 1
