Florian Leicher

Reputation: 21

Keras custom Loss function with two inputs

I could not find anyone else with this problem, so I am asking here. I want to implement a custom loss function in Keras/TensorFlow that treats different columns of the y values differently. I read that this is not possible out of the box, because a Keras loss function takes only two parameters, y_true and y_pred.

Instead, I tried to split each of them into two parts to get what I want:

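from keras import backend as K  # Keras 2 backend API (K.mean, K.square)

def noise_loss_fct9(y_true, y_pred):
  lmbd = 1.0   # weight of the process error relative to the measurement error
  border = 9   # columns [0, border) are measurements, [border, end) are process values
  y_true_measurement = y_true[:, :border]
  y_true_process = y_true[:, border:]
  y_pred_measurement = y_pred[:, :border]
  y_pred_process = y_pred[:, border:]

  # Mean squared error of each column group, one value per sample (axis=-1)
  error_measurement = K.mean(K.square(
      y_true_measurement - y_pred_measurement
  ), axis=-1)
  error_process = K.mean(K.square(
      y_true_process - y_pred_process
  ), axis=-1)

  return error_measurement + lmbd * error_process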
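To sanity-check the values this loss should produce, here is a NumPy mirror of the same computation. It is a sketch, assuming `y_true` and `y_pred` are 2-D arrays of shape `(batch, n_columns)`; `noise_loss_np` is a hypothetical name, not part of Keras. Like the Keras version with `axis=-1`, it returns one value per sample.

```python
import numpy as np

def noise_loss_np(y_true, y_pred, border=9, lmbd=1.0):
    """NumPy mirror of noise_loss_fct9: per-sample loss, shape (batch,)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Columns [0, border) are "measurement", columns [border, end) are "process".
    sq = (y_true - y_pred) ** 2
    error_measurement = sq[:, :border].mean(axis=-1)
    error_process = sq[:, border:].mean(axis=-1)
    return error_measurement + lmbd * error_process

# Two samples with 12 columns: 9 measurement + 3 process columns.
y_true = np.zeros((2, 12))
y_pred = np.zeros((2, 12))
y_pred[0, :9] = 1.0   # sample 0: error only in the measurement part
y_pred[1, 9:] = 2.0   # sample 1: error only in the process part
print(noise_loss_np(y_true, y_pred))  # → [1. 4.]
```

One thing this makes easy to see: if the targets have 9 or fewer columns, the process slice is empty and its mean is nan (NumPy warns "Mean of empty slice"), and the Keras version would likely return nan in the same situation. Whether that applies here depends on the shape of the targets, which the question does not state.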

Even though the model compiles with this loss function, the training log shows nan as the loss:

Epoch 1/10
95s - loss: nan
Epoch 2/10
87s - loss: nan

Does this mean this approach is not valid at all (even though the model compiles), or is the loss just not being displayed for some reason? If it is not valid, what should I do instead?

I am grateful for any comments.

Upvotes: 2

Views: 1223

Answers (2)

Daniel Möller

Reputation: 86600

A loss function must return a "number" (or perhaps a tensor with just one number), not a tensor with many numbers.

When you pass "axis" to "K.mean", the result stays a tensor with many entries (one per sample).

Try "axis=None" or simply remove the axis parameter.
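For reference, here is what the two reductions look like on plain arrays. This is a NumPy sketch; `per_sample_loss` and `scalar_loss` are hypothetical helper names, and the random array only stands in for the squared errors. With `axis=-1` the loss is one value per sample; with `axis=None` it is a single scalar. Keras's own `mean_squared_error` also averages with `axis=-1`, and Keras then reduces the per-sample values to one number (by default, their mean over the batch), so for rows of equal width both variants should report the same batch loss.

```python
import numpy as np

def per_sample_loss(sq_err, border=9, lmbd=1.0):
    # axis=-1: average within each row -> one value per sample, shape (batch,)
    return sq_err[:, :border].mean(axis=-1) + lmbd * sq_err[:, border:].mean(axis=-1)

def scalar_loss(sq_err, border=9, lmbd=1.0):
    # axis=None: average over every entry of each part -> a single scalar
    return sq_err[:, :border].mean() + lmbd * sq_err[:, border:].mean()

rng = np.random.default_rng(0)
sq_err = rng.random((4, 12))  # stand-in for (y_true - y_pred) ** 2

per_sample = per_sample_loss(sq_err)
print(per_sample.shape)  # → (4,)
# Averaging the per-sample values over the batch gives the axis=None scalar.
print(np.isclose(per_sample.mean(), scalar_loss(sq_err)))  # → True
```

So dropping `axis` changes the shape the function returns, but on its own it may not change the value Keras reports, which is worth keeping in mind if the nan persists.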


It's also possible that the problem comes from earlier in the model. To rule that out, first try your model with the regular "mse" loss function.
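Alongside swapping in "mse", it may be worth confirming that the training arrays are finite, since a single NaN or inf in the inputs or targets is enough to make the reported batch loss nan whichever loss is used. A minimal NumPy sketch, assuming the data are NumPy arrays; `x_train`, `y_train`, and `count_non_finite` are hypothetical names with made-up data:

```python
import numpy as np

def count_non_finite(name, arr):
    """Print and return how many entries of arr are NaN or +/-inf."""
    arr = np.asarray(arr, dtype=float)
    n_nan = int(np.isnan(arr).sum())
    n_inf = int(np.isinf(arr).sum())
    print(f"{name}: {n_nan} NaN, {n_inf} inf")
    return n_nan + n_inf

# Hypothetical data standing in for the real training set.
x_train = np.random.default_rng(1).normal(size=(5, 3))
y_train = np.zeros((5, 12))
y_train[2, 10] = np.nan  # one bad target value

count_non_finite("x_train", x_train)  # prints "x_train: 0 NaN, 0 inf"
count_non_finite("y_train", y_train)  # prints "y_train: 1 NaN, 0 inf"
```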

Upvotes: 1

Jessica Alan

Reputation: 728

Did you mean two outputs?

What you seem to be trying to do is certainly possible. The fact that you get nan suggests that something is wrong with the way you are implementing it. I recommend checking intermediate values in your loss function for NaN to see where the error is coming from (math.isnan() takes a single plain Python number; for tensors, tf.math.is_nan is the element-wise equivalent). Can you provide more details on what you're trying to do?
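The "check at certain points" idea can be sketched in NumPy by computing each intermediate quantity of the loss separately and reporting the first one that contains a NaN. `first_nan_stage` is a hypothetical helper; inside a real TensorFlow graph the analogous tool would be `tf.debugging.check_numerics`, which raises an error when a tensor contains NaN or inf.

```python
import warnings
import numpy as np

def first_nan_stage(y_true, y_pred, border=9, lmbd=1.0):
    """Return the name of the first intermediate value containing NaN, or None."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    with warnings.catch_warnings():
        # A mean over an empty slice warns and returns NaN; report it instead.
        warnings.simplefilter("ignore", RuntimeWarning)
        error_measurement = np.mean((y_true[:, :border] - y_pred[:, :border]) ** 2, axis=-1)
        error_process = np.mean((y_true[:, border:] - y_pred[:, border:]) ** 2, axis=-1)
    stages = [
        ("y_true", y_true),
        ("y_pred", y_pred),
        ("error_measurement", error_measurement),
        ("error_process", error_process),
        ("total", error_measurement + lmbd * error_process),
    ]
    for name, value in stages:
        if np.isnan(value).any():
            return name
    return None

print(first_nan_stage(np.zeros((2, 12)), np.ones((2, 12))))  # → None
print(first_nan_stage(np.zeros((2, 9)), np.ones((2, 9))))    # → error_process
```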

Upvotes: 0
