I am writing a specialised version of the mean squared error function, where the label has the form
(x, y, w)
and w is a weight. The model output has the form
(x, y)
I want to penalise certain cases more heavily, indicating this with the weight component w, and also penalise the vertical part of the error more, using hard-coded constants. My failed attempt at the function is:
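To make the intended computation concrete, here is a small NumPy reference of the per-sample loss described above. It is a sketch for checking numbers by hand, not a Keras loss. It assumes labels of shape (batch, 3) holding (x, y, w), predictions of shape (batch, 2), and the 0.4/0.6 horizontal/vertical constants used in the attempt below; the helper name `reference_loss` is hypothetical.

```python
import numpy as np


def reference_loss(y_true_in, y_pred, axis_weights=(0.4, 0.6)):
    """Hypothetical helper: per-sample weighted squared error.

    y_true_in: array of shape (batch, 3) with columns (x, y, w)
    y_pred:    array of shape (batch, 2) with columns (x, y)
    Returns an array of shape (batch,).
    """
    y_true_in = np.asarray(y_true_in, dtype=np.float32)
    y_pred = np.asarray(y_pred, dtype=np.float32)
    xy_true = y_true_in[:, :2]   # (batch, 2): target x, y
    w = y_true_in[:, 2]          # (batch,): weight kept as a 1-D vector
    sq = (y_pred - xy_true) ** 2  # (batch, 2): squared error per axis
    # weight the horizontal and vertical errors, then sum per sample
    per_point = (sq * np.asarray(axis_weights, dtype=np.float32)).sum(axis=1)
    return w * per_point         # (batch,)


labels = [[0.0, 0.0, 1.0],
          [1.0, 1.0, 2.0]]
preds = [[1.0, 2.0],
         [1.0, 0.0]]
print(reference_loss(labels, preds))
```

For the first sample the squared errors are 1 and 4, giving 0.4·1 + 0.6·4 = 2.8 with weight 1; for the second they are 0 and 1, giving 0.6 with weight 2, so 1.2.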
```python
def mod_point_loss_weighted(y_true_in, y_pred):
    import tensorflow as tf
    # cut out (x, y) part of label
    y_true = tf.gather(y_true_in, [0, 1], axis=1)
    # cut out weight part of label
    y_weight = tf.gather(y_true_in, [2], axis=1)
    y_true_cast = tf.cast(y_true, tf.float32)
    diff = y_pred - y_true_cast
    squared_diff = tf.square(diff)
    # penalise vertical error slightly more
    weighted_loss = tf.multiply(squared_diff, [0.4, 0.6])
    weighted_loss = tf.reduce_sum(weighted_loss, axis=1)
    # add custom weight
    loss = tf.multiply(weighted_loss, y_weight)
    return loss
```
However, I get "Node: 'SquaredDifference' required broadcastable shapes". Any ideas what I am doing wrong? It seems that `y_true = tf.gather(y_true_in, [0, 1], axis=1)` is somehow not right, but I cannot see how. I use TensorFlow v2.10.
The following is the complete error trace (when trying to train):
```
2023-06-23 12:17:42.718546: W tensorflow/core/framework/op_kernel.cc:1768] INVALID_ARGUMENT: required broadcastable shapes
2023-06-23 12:17:42.718583: W tensorflow/core/framework/op_kernel.cc:1768] INVALID_ARGUMENT: required broadcastable shapes
Traceback (most recent call last):
  File "/home/enness/wqla/s3/train_topology.py", line 6, in <module>
    application.run("train:schema")
  File "/home/enness/wqla/s3/AuxBase/Application.py", line 103, in run
    context.execute(stage)
  File "/home/enness/wqla/s3/AuxBase/ContextImpl.py", line 115, in execute
    self.__execute_training()
  File "/home/enness/wqla/s3/AuxBase/ContextImpl.py", line 168, in __execute_training
    ModelTraining.train(self)
  File "/home/enness/wqla/s3/AuxBase/ModelTraining.py", line 132, in train
    history = model.fit(
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'mod_point_loss_weighted/sub' defined at (most recent call last):
  File "/home/enness/wqla/s3/train_topology.py", line 6, in <module>
    application.run("train:schema")
  File "/home/enness/wqla/s3/AuxBase/Application.py", line 103, in run
    context.execute(stage)
  File "/home/enness/wqla/s3/AuxBase/ContextImpl.py", line 115, in execute
    self.__execute_training()
  File "/home/enness/wqla/s3/AuxBase/ContextImpl.py", line 168, in __execute_training
    ModelTraining.train(self)
  File "/home/enness/wqla/s3/AuxBase/ModelTraining.py", line 132, in train
    history = model.fit(
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
    return fn(*args, **kwargs)
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1564, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1160, in train_function
    return step_function(self, iterator)
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1146, in step_function
    outputs = model.distribute_strategy.run(run_step, args=(data,))
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1135, in run_step
    outputs = model.train_step(data)
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 994, in train_step
    loss = self.compute_loss(x, y, y_pred, sample_weight)
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/training.py", line 1052, in compute_loss
    return self.compiled_loss(
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/engine/compile_utils.py", line 265, in __call__
    loss_value = loss_obj(y_t, y_p, sample_weight=sw)
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/losses.py", line 152, in __call__
    losses = call_fn(y_true, y_pred)
  File "/home/enness/wqla/s3/venv/lib/python3.9/site-packages/keras/losses.py", line 272, in call
    return ag_fn(y_true, y_pred, **self._fn_kwargs)
  File "/home/enness/wqla/s3/AuxBase/TrainingUtil.py", line 51, in mod_point_loss_weighted
    diff = y_pred - y_true_cast
Node: 'mod_point_loss_weighted/sub'
required broadcastable shapes
	 [[{{node mod_point_loss_weighted/sub}}]] [Op:__inference_train_function_5472]
2023-06-23 12:17:43.033798: W tensorflow/core/kernels/data/generator_dataset_op.cc:108] Error occurred when finalizing GeneratorDataset iterator: FAILED_PRECONDITION: Python interpreter state is not initialized. The process may be terminated.
	 [[{{node PyFunc}}]]
```
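"Required broadcastable shapes" means the two operands of an elementwise op (here the subtraction) cannot be broadcast to a common shape. TensorFlow follows NumPy's broadcasting rules, so the shapes involved can be checked with a NumPy sketch; the batch size of 4 and the all-zero arrays are dummy assumptions. It shows that selecting columns 0 and 1 does yield a (batch, 2) array matching the predictions, that a (batch, 2) minus (batch, 3) subtraction is the kind of mismatch the error reports, and a separate pitfall in the attempt: multiplying the (batch,) per-sample loss by the (batch, 1) weight column broadcasts to (batch, batch) rather than (batch,).

```python
import numpy as np

batch = 4
y_true_in = np.zeros((batch, 3), dtype=np.float32)  # dummy (x, y, w) labels
y_pred = np.zeros((batch, 2), dtype=np.float32)     # dummy (x, y) predictions

# Column selection, equivalent to tf.gather(y_true_in, [0, 1], axis=1)
y_true = np.take(y_true_in, [0, 1], axis=1)
print(y_true.shape)  # (4, 2): same shape as y_pred, so subtraction works

# Mismatched trailing dimensions cannot be broadcast
try:
    y_pred - y_true_in  # (4, 2) - (4, 3)
except ValueError as err:
    print("cannot broadcast:", err)

# Separate pitfall: (batch,) loss times (batch, 1) weight column
per_sample = np.ones(batch, dtype=np.float32)  # shape (4,)
y_weight = np.take(y_true_in, [2], axis=1)     # shape (4, 1)
print((per_sample * y_weight).shape)           # (4, 4), not (4,)
print((per_sample * y_weight[:, 0]).shape)     # (4,)
```

Since the subtraction only fails when `y_pred` and the gathered `y_true` disagree in shape at run time, printing the actual shapes that reach the loss is a reasonable first check.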
Here is the working version; the weight vector needed to be transposed:
```python
def mod_point_loss_weighted(y_true_in, y_pred):
    """
    Custom loss function penalising the vertical component of the point error more and applying the per-sample weight
    """
    import tensorflow as tf
    y_true = tf.gather(y_true_in, [0, 1], axis=1)  # (batch, 2): x, y
    y_weight = tf.gather(y_true_in, [2], axis=1)   # (batch, 1): w
    y_true_cast = tf.cast(y_true, tf.float32)
    y_weight_cast = tf.cast(y_weight, tf.float32)
    # turn the weight column (batch, 1) into a row (1, batch)
    y_weight_cast = tf.transpose(y_weight_cast)
    diff = y_pred - y_true_cast
    loss = tf.keras.backend.square(diff)
    # penalise the vertical (y) error slightly more
    loss = loss * [0.45, 0.55]
    # (1, batch) * (batch,) -> (1, batch)
    loss = y_weight_cast * tf.keras.backend.sum(loss, axis=1)
    # take the single row: per-sample loss of shape (batch,)
    return loss[0]
```
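In the working version, `tf.transpose` turns the (batch, 1) weight column into a (1, batch) row, multiplying it by the (batch,) summed loss gives (1, batch), and `loss[0]` takes that single row, so the result is a per-sample loss of shape (batch,). The NumPy sketch below (random dummy data, hypothetical variable names) checks that this equals simply using the weight column as a 1-D vector, which in TensorFlow could be written as `y_true_in[:, 2]` or `tf.squeeze(y_weight, axis=1)`.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 5
y_true_in = rng.random((batch, 3)).astype(np.float32)  # dummy (x, y, w) labels
y_pred = rng.random((batch, 2)).astype(np.float32)     # dummy (x, y) predictions

sq = (y_pred - y_true_in[:, :2]) ** 2
per_point = (sq * np.array([0.45, 0.55], dtype=np.float32)).sum(axis=1)  # (batch,)

# Answer's approach: transpose the (batch, 1) weight column, multiply, take row 0
w_col = y_true_in[:, 2:3]                 # (batch, 1)
via_transpose = (w_col.T * per_point)[0]  # (1, batch) -> (batch,)

# Equivalent: treat the weight as a 1-D vector from the start
via_vector = y_true_in[:, 2] * per_point  # (batch,)

print(via_transpose.shape, np.allclose(via_transpose, via_vector))
```

Keeping the weight 1-D avoids both the transpose and the final indexing.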
Lessons learned: