I'm trying to use a custom loss function that accepts sample weights. I'm following this documentation: https://www.tensorflow.org/api_docs/python/tf/keras/losses/Loss.
However, when I try to use the following cost function:
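```python
import tensorflow as tf
from tensorflow.keras.losses import Loss

# tf_get_deltaE2000 and tf_Xtrain_labels_max are defined elsewhere in my code
class deltaE(Loss):
    def __call__(self, y_true, y_pred, sample_weight):
        errors = tf_get_deltaE2000(y_true * tf_Xtrain_labels_max, y_pred * tf_Xtrain_labels_max)
        errors *= sample_weight
        return tf.math.reduce_mean(errors, axis=-1)

loss_deltaE = deltaE()
```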
I get this error when calling `Model.fit`:
```
TypeError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:571 train_function *
outputs = self.distribute_strategy.run(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:951 run **
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
return fn(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:543 train_step **
self.compiled_metrics.update_state(y, y_pred, sample_weight)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/compile_utils.py:411 update_state
metric_obj.update_state(y_t, y_p)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/metrics_utils.py:90 decorated
update_op = update_state_fn(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/metrics.py:603 update_state
matches = self._fn(y_true, y_pred, **self._fn_kwargs)
TypeError: __call__() missing 1 required positional argument: 'sample_weight'
```
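The last frames of the traceback come from `compiled_metrics.update_state`, which suggests Keras is also evaluating the object as a metric (presumably it was passed in `metrics=` as well; the `compile` call is not shown). The metric wrapper calls it as `self._fn(y_true, y_pred, **self._fn_kwargs)`, with no `sample_weight`, so the error is an ordinary Python signature mismatch. The plain-Python mock below reproduces it; the class and function names are hypothetical and this is not Keras code.

```python
class LossRequiringWeight:
    """Mirrors the question's signature: sample_weight is a required argument."""

    def __call__(self, y_true, y_pred, sample_weight):
        errors = [abs(t - p) * w for t, p, w in zip(y_true, y_pred, sample_weight)]
        return sum(errors) / len(errors)


class LossWithOptionalWeight:
    """Same weighted mean absolute error, but sample_weight defaults to None (all weights 1)."""

    def __call__(self, y_true, y_pred, sample_weight=None):
        if sample_weight is None:
            sample_weight = [1.0] * len(y_true)
        errors = [abs(t - p) * w for t, p, w in zip(y_true, y_pred, sample_weight)]
        return sum(errors) / len(errors)


def metric_style_call(fn, y_true, y_pred):
    # Hypothetical simplification of the metric wrapper: only two arguments are passed
    return fn(y_true, y_pred)


try:
    metric_style_call(LossRequiringWeight(), [1.0, 3.0], [0.0, 1.0])
except TypeError as exc:
    print("TypeError:", exc)  # missing required positional argument 'sample_weight'

print(metric_style_call(LossWithOptionalWeight(), [1.0, 3.0], [0.0, 1.0]))  # 1.5
```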
I'm using a generator that yields a tuple of length 3 (inputs, targets, sample weights), as required, and I've checked that it works properly.
The cost function itself works too: with the plain function below, the model trains without problems.
```python
import tensorflow as tf

def loss_deltaE(y_true, y_pred):
    errors = tf_get_deltaE2000(y_true * tf_Xtrain_labels_max, y_pred * tf_Xtrain_labels_max)
    return tf.math.reduce_mean(errors, axis=-1)
```
If anyone has a clue, I'd appreciate it. Thanks in advance!
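For reference, the `tf.keras.losses.Loss` documentation linked above has subclasses implement `call(y_true, y_pred)` and leave `__call__` alone. The inherited `__call__(y_true, y_pred, sample_weight=None)` multiplies the per-sample values returned by `call` by `sample_weight` and then applies the reduction (by default, sum over the batch divided by the number of values), so the same object also works when Keras calls it with only two arguments. A minimal sketch assuming TensorFlow 2.x: `hypothetical_color_error` is a made-up stand-in for the asker's `tf_get_deltaE2000`, and the `labels_max` parameter stands in for `tf_Xtrain_labels_max`.

```python
import tensorflow as tf


def hypothetical_color_error(y_true, y_pred):
    # HYPOTHETICAL stand-in for tf_get_deltaE2000: per-element absolute difference
    return tf.abs(y_true - y_pred)


class DeltaELoss(tf.keras.losses.Loss):
    """Implements call(); the inherited __call__ applies sample_weight and reduction."""

    def __init__(self, labels_max=1.0, name="delta_e"):
        super().__init__(name=name)
        # Assumed scale factor, standing in for tf_Xtrain_labels_max
        self.labels_max = labels_max

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        errors = hypothetical_color_error(y_true * self.labels_max, y_pred * self.labels_max)
        # One value per sample; weighting happens in Loss.__call__, not here
        return tf.math.reduce_mean(errors, axis=-1)


loss_deltaE = DeltaELoss()
y_true = tf.constant([[0.0, 1.0], [1.0, 1.0]])
y_pred = tf.constant([[0.5, 1.0], [0.0, 1.0]])
w = tf.constant([2.0, 0.0])

weighted = float(loss_deltaE(y_true, y_pred, sample_weight=w))  # (2 * 0.25 + 0 * 0.5) / 2
unweighted = float(loss_deltaE(y_true, y_pred))                 # (0.25 + 0.5) / 2
print(weighted, unweighted)
```

With this pattern the object can be passed to `compile(loss=...)` and also listed in `metrics=` without the `TypeError`, because `sample_weight` is optional in the inherited `__call__`.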
This is a workaround for passing additional arguments to a custom loss function. The trick is to feed the targets and the sample weights to the model as extra (dummy) inputs, so that the loss can be built inside the model with `add_loss` and computed in the correct way.
Here is a dummy example for a regression problem:
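```python
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

def mse(y_true, y_pred, sample_weight):
    # Sample-weighted mean squared error
    error = y_true - y_pred
    return K.mean(K.square(error) * sample_weight)

X = np.random.uniform(0, 1, (1000, 10))
y = np.random.uniform(0, 1, (1000, 1))
W = np.random.uniform(1, 2, (1000, 1))

inp = Input(shape=(10,))
true = Input(shape=(1,))           # dummy input carrying the targets
sample_weight = Input(shape=(1,))  # dummy input carrying the per-sample weights
x = Dense(32, activation='relu')(inp)
out = Dense(1)(x)
m = Model([inp, true, sample_weight], out)
# The loss is built from the dummy inputs, so compile() gets loss=None
m.add_loss(mse(true, out, sample_weight))
m.compile(loss=None, optimizer='adam')
history = m.fit([X, y, W], y, epochs=10)

# final fitted model to compute predictions (shares the trained layers)
final_m = Model(inp, out)
```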
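The quantity the `mse` helper is meant to register through `add_loss` is the sample-weighted mean squared error over a batch of $N$ samples, with targets $y_i$, predictions $\hat{y}_i$ and weights $w_i$:

$$
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} w_i \, (y_i - \hat{y}_i)^2
$$

Because `K.mean` divides by $N$ rather than by $\sum_i w_i$, multiplying every weight by the same constant multiplies the loss by that constant.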