EnnessGe

Reputation: 1

Tensorflow custom MSE loss function

I am writing a specialised version of the mean squared error loss function where the label has the form

(x, y, w)

and w is a weight. The model output has the form

(x, y)

I want to penalise certain cases more, indicated by the weight component w, and also penalise the vertical part of the error more using hard-coded constants. My failed attempt at the function is:

def mod_point_loss_weighted(y_true_in, y_pred):
    import tensorflow as tf

    # cut out (x, y) part of label
    y_true = tf.gather(y_true_in, [0, 1], axis=1)

    # cut out weight part of label
    y_weight = tf.gather(y_true_in, [2], axis=1)

    y_true_cast = tf.cast(y_true, tf.float32)
    diff = y_pred - y_true_cast
    squared_diff = tf.square(diff)

    # penalise vertical error slightly more
    weighted_loss = tf.multiply(squared_diff, [0.4, 0.6])
    weighted_loss = tf.reduce_sum(weighted_loss, axis=1)

    # add custom weight 
    loss = tf.multiply(weighted_loss, y_weight)

    return loss
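
In other words, the per-sample loss I am after is roughly w * (0.4 * (x_pred - x_true)^2 + 0.6 * (y_pred - y_true)^2), with the 0.4/0.6 constants just hard coded for now.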

However, I get "Node: 'SquaredDifference' required broadcastable shapes". Any ideas what I am doing wrong? It seems that y_true = tf.gather(y_true_in, [0, 1], axis=1) is somehow not right, but I cannot see how. I am using TensorFlow 2.10.
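
For what it is worth, a quick standalone check of the slicing on a tiny constant batch (values below are made up, run eagerly outside of training) gives the shapes I would expect:

import tensorflow as tf

# tiny made-up batch: labels carry (x, y, w), predictions carry (x, y)
y_true_in = tf.constant([[1.0, 2.0, 0.5],
                         [3.0, 4.0, 2.0]])        # shape (2, 3)
y_pred = tf.constant([[1.1, 1.9],
                      [2.8, 4.2]])                # shape (2, 2)

y_true = tf.gather(y_true_in, [0, 1], axis=1)     # (x, y) part -> shape (2, 2)
y_weight = tf.gather(y_true_in, [2], axis=1)      # weight part -> shape (2, 1)
print(y_true.shape, y_weight.shape)               # (2, 2) (2, 1)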

The following is the complete error trace (when trying to train):

2023-06-23 12:17:42.718546: W tensorflow/core/framework/op_kernel.cc:1768] INVALID_ARGUMENT: required broadcastable shapes
2023-06-23 12:17:42.718583: W tensorflow/core/framework/op_kernel.cc:1768] INVALID_ARGUMENT: required broadcastable shapes
Traceback (most recent call last):
  File "/home/enness/wqla/s3/train_topology.py", line 6, in <module>
    application.run("train:schema")
  File "/home/enness/wqla/s3/AuxBase/Application.py", line 103, in run
    context.execute(stage)
  File "/home/enness/wqla/s3/AuxBase/ContextImpl.py", line 115, in execute
    self.__execute_training()
  File "/home/enness/wqla/s3/AuxBase/ContextImpl.py", line 168, in __execute_training
    ModelTraining.train(self)
  File "/home/enness/wqla/s3/AuxBase/ModelTraining.py", line 132, in train
    history = model.fit(
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'mod_point_loss_weighted/sub' defined at (most recent call last):
    File "/home/enness/wqla/s3/train_topology.py", line 6, in <module>
      application.run("train:schema")
    File "/home/enness/wqla/s3/AuxBase/Application.py", line 103, in run
      context.execute(stage)
    File "/home/enness/wqla/s3/AuxBase/ContextImpl.py", line 115, in execute
      self.__execute_training()
    File "/home/enness/wqla/s3/AuxBase/ContextImpl.py", line 168, in __execute_training
      ModelTraining.train(self)
    File "/home/enness/wqla/s3/AuxBase/ModelTraining.py", line 132, in train
      history = model.fit(
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 994, in train_step
      loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1052, in compute_loss
      return self.compiled_loss(
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/compile_utils.py", line 265, in __call__
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/losses.py", line 152, in __call__
      losses = call_fn(y_true, y_pred)
    File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/losses.py", line 272, in call
      return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/home/enness/wqla/s3/AuxBase/TrainingUtil.py", line 51, in mod_point_loss_weighted
      diff = y_pred - y_true_cast
Node: 'mod_point_loss_weighted/sub'
required broadcastable shapes
     [[{{node mod_point_loss_weighted/sub}}]] [Op:__inference_train_function_5472]
2023-06-23 12:17:43.033798: W tensorflow/core/kernels/data/generator_dataset_op.cc:108] Error occurred when finalizing GeneratorDataset iterator: FAILED_PRECONDITION: Python interpreter state is not initialized. The process may be terminated.
     [[{{node PyFunc}}]]

Upvotes: 0

Views: 133

Answers (1)

EnnessGe

Reputation: 1

Here is the working version; the weight vector needed to be transposed:

def mod_point_loss_weighted(y_true_in, y_pred):
    """
    Custom loss function penalising the vertical component of the point error
    more and applying the per-sample weight carried in the label
    """
    import tensorflow as tf

    # cut out the (x, y) part and the weight part of the label
    y_true = tf.gather(y_true_in, [0, 1], axis=1)
    y_weight = tf.gather(y_true_in, [2], axis=1)

    y_true_cast = tf.cast(y_true, tf.float32)
    y_weight_cast = tf.cast(y_weight, tf.float32)

    # transpose the weight column (batch, 1) -> (1, batch) so it lines up
    # with the per-sample loss vector below
    y_weight_cast = tf.transpose(y_weight_cast)

    # squared error, with the vertical component penalised slightly more
    diff = y_pred - y_true_cast
    loss = tf.keras.backend.square(diff)
    loss = loss * [0.45, 0.55]

    # per-sample loss scaled by the custom weight; shape is (1, batch)
    loss = y_weight_cast * tf.keras.backend.sum(loss, axis=1)

    # drop the leading dimension -> one loss value per sample
    return loss[0]
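
For context, this is roughly how the loss is wired in; the model architecture and data below are only illustrative stand-ins, not the actual training setup:

import numpy as np
import tensorflow as tf

# illustrative model with a 2-unit (x, y) output head
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(2),
])
model.compile(optimizer="adam", loss=mod_point_loss_weighted)

# labels carry (x, y, w) even though the model only predicts (x, y)
X = np.random.rand(32, 8).astype("float32")
Y = np.random.rand(32, 3).astype("float32")   # third column is the weight w
model.fit(X, Y, epochs=1, verbose=0)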

Lessons learned:

  • The error message can be misleading about the exact location of an error
  • Debug the loss function on small constant X's and Y's outside of the training regime (see the sketch below)
  • The metric must be configured as well (and pay attention to whether the error is for the loss or the metric!)
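
As a sketch of the second point (shapes and values below are made up), evaluating the loss eagerly on a tiny constant batch makes it easy to confirm that it returns one value per sample before it goes anywhere near model.compile:

import tensorflow as tf

# tiny made-up batch, evaluated eagerly outside of training
y_true_in = tf.constant([[1.0, 2.0, 0.5],
                         [3.0, 4.0, 2.0]])    # (batch, 3): x, y, w
y_pred = tf.constant([[1.1, 1.9],
                      [2.8, 4.2]])            # (batch, 2): x, y

loss = mod_point_loss_weighted(y_true_in, y_pred)
print(loss.shape)    # expect (2,), i.e. one loss value per sample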

Upvotes: 0
