Reputation: 258
I'm working on a multi-label classification problem where instead of each target index representing a distinct class, it represents some amount of time into the future. On top of wanting my predicted label to match the target label, I want an extra term to enforce some temporal aspect of the learning.
E.g.:
y_true = [1., 1., 1., 0.]
y_pred = [0.75, 0.81, 0.93, 0.65]
Above, the truth label implies something occurring during the first three indices.
I want to easily be able to mix and match loss functions.
I have a couple custom loss functions for overall accuracy, each wrapped within functions for adjustable arguments:
def weighted_binary_crossentropy(pos_weight=1):
    def weighted_binary_crossentropy_(Y_true, Y_pred):
        ...
        return tf.reduce_mean(loss, axis=-1)
    return weighted_binary_crossentropy_
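(For readers unfamiliar with the pattern, the elided body might look roughly like the following. This is a minimal NumPy sketch of the weighted-BCE math, not the author's actual TF implementation; `pos_weight` scales only the positive-class term.)

```python
import numpy as np

def weighted_binary_crossentropy(pos_weight=1):
    # Hypothetical body for illustration: standard binary cross-entropy
    # with the positive-class term scaled by pos_weight.
    def weighted_binary_crossentropy_(Y_true, Y_pred):
        Y_pred = np.clip(Y_pred, 1e-7, 1 - 1e-7)  # avoid log(0)
        loss = -(pos_weight * Y_true * np.log(Y_pred)
                 + (1 - Y_true) * np.log(1 - Y_pred))
        return np.mean(loss, axis=-1)  # one scalar per sample
    return weighted_binary_crossentropy_
```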
def mean_squared_error(threshold=0.5):
    def mean_squared_error_(Y_true, Y_pred):
        ...
        return tf.reduce_mean(loss, axis=-1)
    return mean_squared_error_
I also have a custom loss function to enforce the predicted label ending at the same time as the truth label (I haven't made use of the threshold argument here yet):
def end_time_error(threshold=0.5):
    def end_time_error_(Y_true, Y_pred):
        _, n_times = K.int_shape(Y_true)
        weights = K.arange(1, n_times + 1, dtype=float)
        Y_true = tf.multiply(Y_true, weights)
        argmax_true = K.argmax(Y_true, axis=1)
        argmax_pred = K.argmax(Y_pred, axis=1)
        loss = tf.math.squared_difference(argmax_true, argmax_pred)
        return tf.reduce_mean(loss, axis=-1)
    return end_time_error_
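(The weighting trick is easiest to check in plain NumPy. The sketch below mirrors the same math outside TF: multiplying a 0/1 truth row by `arange(1, n+1)` makes `argmax` land on the *last* positive index, i.e. the event's end time; the prediction is left unweighted, as in the TF version above, so its `argmax` is simply the peak probability.)

```python
import numpy as np

def end_time_error_np(Y_true, Y_pred):
    # NumPy sketch of the end-time penalty, illustrative only.
    n_times = Y_true.shape[1]
    weights = np.arange(1, n_times + 1, dtype=float)
    # For a 0/1 truth row, weighting by arange makes argmax pick
    # the LAST index that equals 1 (the end of the event).
    argmax_true = np.argmax(Y_true * weights, axis=1)
    # The prediction is not weighted: argmax is the peak probability.
    argmax_pred = np.argmax(Y_pred, axis=1)
    return np.mean((argmax_true - argmax_pred) ** 2)
```

With the example from the question, truth ends at index 2 and the predicted peak is also at index 2, so the penalty is zero.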
Sometimes I might want to combine end_time_error with weighted_binary_crossentropy, sometimes with mean_squared_error, and I have plenty of other loss functions to experiment with. I don't want to have to code a new combined loss function for each pair.
I've tried making a meta-loss function that combines loss functions (globally defined in the same script).
def combination_loss(loss_dict, combine='add', weights=[]):
    losses = []
    if not weights:
        weights = [1] * len(loss_dict)
    for (loss_func, loss_args), weight in zip(loss_dict.items(), weights):
        assert loss_func in globals().keys()
        loss_func = eval(loss_func)
        loss = loss_func(loss_args)
        losses.append(loss * weight)
    if combine == 'add':
        loss = sum(losses)
    elif combine == 'multiply':
        loss = np.prod(losses)
    return loss
To use this:
loss_args = {'loss_dict': {'weighted_binary_crossentropy': {'pos_weight': 1},
                           'end_time_error': {}},
             'combine': 'add',
             'weights': [0.75, 0.25]}

model.compile(loss=combination_loss(**loss_args), ...)
Error:
File "C:\...\losses.py", line 165, in combination_loss
losses.append(loss * weight)
TypeError: unsupported operand type(s) for *: 'function' and 'float'
I'm playing loose with functions, so I'm not surprised this failed. But I'm not sure how to get what I want.
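(The failure can be reproduced in pure Python with a stub closure standing in for the real loss. Calling the factory returns the *inner function*, not a tensor, so `losses` ends up holding functions, and multiplying a function by a float is exactly the `TypeError` in the traceback.)

```python
# Stub factory mirroring the structure of the real losses:
def weighted_binary_crossentropy(pos_weight=1):
    def weighted_binary_crossentropy_(Y_true, Y_pred):
        return 0.0  # stub body
    return weighted_binary_crossentropy_

# combination_loss effectively does this:
loss = weighted_binary_crossentropy({'pos_weight': 1})
# `loss` is the inner function (a closure), not a tensor,
# so weighting it fails:
assert callable(loss)
try:
    loss * 0.75
except TypeError:
    pass  # unsupported operand type(s) for *: 'function' and 'float'
```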
How can I combine functions with weights in combination_loss? Or should I be doing that directly in the model.compile() call using a lambda function?
Ditching combination_loss:
losses = []
for loss_, loss_args_ in loss_args['loss_dict'].items():
    losses.append(get_loss(loss_)(**loss_args_))

loss = lambda y_true, y_pred: [l(y_true, y_pred) * w for l, w
                               in zip(losses, loss_args['weights'])]

model.compile(loss=loss, ...)
Error:
File "C:\...\losses.py", line 139, in end_time_error_
weights = K.arange(1, n_times + 1, dtype=float)
TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
Probably because y_true, y_pred won't work as arguments for wrapped loss functions.
Upvotes: 0
Views: 898
Reputation: 10995
Let's simplify your use case for only two losses:
loss = alpha * loss1 + (1-alpha) * loss2
Then you can do:
def generate_loss(alpha):
    def combination_loss(y_true, y_pred):
        return alpha * loss1(y_true, y_pred) + (1 - alpha) * loss2(y_true, y_pred)
    return combination_loss
Obviously, loss1 and loss2 would be your respective loss functions.
You can use this to generate different loss functions for different alphas:
alpha = 0.7
combination_loss = generate_loss(alpha)
model.compile(loss=combination_loss, ...)
If alpha is supposed to be static, you can also get rid of the outer function generate_loss.
Finally, you can also define this as a lambda function:
model.compile(loss=lambda y_true, y_pred: alpha * loss1(y_true, y_pred) + (1-alpha) * loss2(y_true, y_pred), ...)
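The same closure pattern generalizes to any number of losses: rather than multiplying the loss *functions* by weights, build one function that calls each loss on (y_true, y_pred) and combines the results. A pure-Python sketch under that assumption (the name `combine_losses` and the stub losses are illustrative, not part of your code):

```python
def combine_losses(losses, weights, combine='add'):
    # losses: list of callables f(y_true, y_pred) -> scalar/tensor
    # weights: one numeric weight per loss
    def combined(y_true, y_pred):
        terms = [w * f(y_true, y_pred) for f, w in zip(losses, weights)]
        out = terms[0]
        for t in terms[1:]:
            out = out + t if combine == 'add' else out * t
        return out
    return combined

# Stub losses for demonstration:
l1 = lambda yt, yp: 2.0
l2 = lambda yt, yp: 4.0
loss = combine_losses([l1, l2], [0.75, 0.25])
loss(None, None)  # 0.75 * 2.0 + 0.25 * 4.0 = 2.5
```

The returned callable has the (y_true, y_pred) signature Keras expects, so it can be passed straight to model.compile(loss=...).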
I'm not sure where your bug is (I assume it's the eval, but I can't debug it), but if you simplify it enough like this, or use this as a working example to introduce your losses and weights, it should work.
Upvotes: 1