Spuds

Reputation: 258

Dynamically combining loss functions in tensorflow keras

I'm working on a multi-label classification problem where instead of each target index representing a distinct class, it represents some amount of time into the future. On top of wanting my predicted label to match the target label, I want an extra term to enforce some temporal aspect of the learning.

E.g.:

y_true = [1., 1., 1., 0.]
y_pred = [0.75, 0.81, 0.93, 0.65]

Above, the truth label implies something occurring during the first three indices.

I want to easily be able to mix and match loss functions.

I have a couple of custom loss functions for overall accuracy, each wrapped in an outer function that takes adjustable arguments:

def weighted_binary_crossentropy(pos_weight=1):
    def weighted_binary_crossentropy_(Y_true, Y_pred):
        ...
        return tf.reduce_mean(loss, axis=-1)
    return weighted_binary_crossentropy_

def mean_squared_error(threshold=0.5):
    def mean_squared_error_(Y_true, Y_pred):
        ...
        return tf.reduce_mean(loss, axis=-1)
    return mean_squared_error_
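For concreteness, the elided body of the first factory might look like the following (a sketch of one common formulation, not the asker's actual implementation):

```python
import tensorflow as tf

def weighted_binary_crossentropy(pos_weight=1):
    def weighted_binary_crossentropy_(Y_true, Y_pred):
        # Clip to avoid log(0); weight the positive-class term by pos_weight.
        Y_pred = tf.clip_by_value(Y_pred, 1e-7, 1 - 1e-7)
        loss = -(pos_weight * Y_true * tf.math.log(Y_pred)
                 + (1 - Y_true) * tf.math.log(1 - Y_pred))
        return tf.reduce_mean(loss, axis=-1)
    return weighted_binary_crossentropy_
```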

I also have a custom loss function to enforce the predicted label ending at the same time as the truth label (I haven't made use of the threshold argument here yet):

def end_time_error(threshold=0.5):
    def end_time_error_(Y_true, Y_pred):
        _, n_times = K.int_shape(Y_true)
        weights = K.arange(1, n_times + 1, dtype=float)
        Y_true = tf.multiply(Y_true, weights)
        argmax_true = K.argmax(Y_true, axis=1)
        argmax_pred = K.argmax(Y_pred, axis=1)
        loss = tf.math.squared_difference(argmax_true, argmax_pred)
        return tf.reduce_mean(loss, axis=-1)
    return end_time_error_

Sometimes I might want to combine end_time_error with weighted_binary_crossentropy, sometimes with mean_squared_error, and I have plenty of other loss functions to experiment with. I don't want to have to code a new combined loss function for each pair.

Attempt at solution 1

I've tried making a meta-loss function that combines loss functions (globally defined in the same script).

def combination_loss(loss_dict, combine='add', weights=[]):
    losses = []
    if not weights:
        weights = [1] * len(loss_dict)
    for (loss_func, loss_args), weight in zip(loss_dict.items(), weights):
        assert loss_func in globals().keys()
        loss_func = eval(loss_func)
        loss = loss_func(loss_args)
        losses.append(loss * weight)
    if combine == 'add':
        loss = sum(losses)
    elif combine == 'multiply':
        loss = np.prod(losses)
    return loss

To use this:

loss_args = {'loss_dict':
                 {'weighted_binary_crossentropy': {'pos_weight': 1},
                  'end_time_error': {}},
             'combine': 'add',
             'weights': [0.75, 0.25]}
model.compile(loss=combination_loss(**loss_args), ...)

Error:

  File "C:\...\losses.py", line 165, in combination_loss
    losses.append(loss * weight)

TypeError: unsupported operand type(s) for *: 'function' and 'float'

I'm playing loose with functions, so I'm not surprised this failed. But I'm not sure how to get what I want.

How can I combine functions with weights in combination_loss?

Or should I be doing that directly in the model.compile() call using a lambda function?

--EDIT

Attempt at solution 2

Ditching combination_loss:

losses = []
for loss_, loss_args_ in loss_args['loss_dict'].items():
    losses.append(get_loss(loss_)(**loss_args_))
loss = lambda y_true, y_pred: [l(y_true, y_pred) * w for l, w
                               in zip(losses, loss_args['weights'])]
model.compile(loss=loss, ...)

Error:

  File "C:\...\losses.py", line 139, in end_time_error_
    weights = K.arange(1, n_times + 1, dtype=float)

TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

Probably because K.int_shape(Y_true) returns None for dimensions that aren't known when Keras traces the wrapped loss functions, so n_times + 1 fails.
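The None comes from K.int_shape, which reports static (graph-construction-time) shapes. A variant of end_time_error that reads the dynamic tf.shape instead sidesteps the problem (a sketch, not from the original post; it also includes the factory's return, which the snippet above was missing):

```python
import tensorflow as tf

def end_time_error(threshold=0.5):
    def end_time_error_(Y_true, Y_pred):
        # tf.shape is evaluated at run time, so the time dimension is a
        # concrete tensor even when the static shape is (None, None).
        n_times = tf.shape(Y_true)[-1]
        weights = tf.cast(tf.range(1, n_times + 1), Y_true.dtype)
        # Index of the last active step in the truth vs. the peak of the
        # prediction; penalize their squared distance.
        argmax_true = tf.argmax(Y_true * weights, axis=-1)
        argmax_pred = tf.argmax(Y_pred, axis=-1)
        return tf.math.squared_difference(
            tf.cast(argmax_true, tf.float32), tf.cast(argmax_pred, tf.float32))
    return end_time_error_
```

On the example at the top of the question, the truth's last active index and the prediction's peak both land on index 2, so this loss term is zero.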

Upvotes: 0

Views: 898

Answers (1)

runDOSrun

Reputation: 10995

Let's simplify your use case for only two losses:

loss = alpha * loss1 + (1-alpha) * loss2

Then you can do:

def generate_loss(alpha):
    def combination_loss(y_true, y_pred):
        return alpha * loss1(y_true, y_pred) + (1-alpha) * loss2(y_true, y_pred)
    return combination_loss

Obviously, loss1 and loss2 would be your respective loss functions. You can use this to generate different loss functions for different alphas:

alpha = 0.7
combination_loss = generate_loss(alpha)
model.compile(loss=combination_loss, ...)

If alpha is supposed to be static, you can also get rid of the outer function generate_loss.

Finally, you can also define this as a lambda function:

model.compile(loss=lambda y_true, y_pred: alpha * loss1(y_true, y_pred) + (1-alpha) * loss2(y_true, y_pred), ...)

I'm not sure where your bug is exactly (I can't debug it), but the traceback suggests combination_loss ends up multiplying a loss function object by a float rather than the function's output. If you simplify it like this, or use this as a working skeleton to introduce your own losses and weights, it should work.
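Generalizing this to the asker's N-loss setup: the fix is to combine the outputs of the already-instantiated loss functions inside a single (y_true, y_pred) closure, rather than weighting the function objects themselves. A minimal sketch (the name make_combined_loss and taking a list rather than a dict are my choices, not from either post):

```python
def make_combined_loss(losses, weights=None, combine='add'):
    """Fold several (y_true, y_pred) -> tensor losses into one.

    `losses` are already-instantiated callables, e.g. the result of
    calling weighted_binary_crossentropy(pos_weight=1).
    """
    if weights is None:
        weights = [1.0] * len(losses)

    def combined_loss(y_true, y_pred):
        # Call each loss first, then weight and combine the results.
        terms = [w * fn(y_true, y_pred) for fn, w in zip(losses, weights)]
        total = terms[0]
        for term in terms[1:]:
            total = total + term if combine == 'add' else total * term
        return total

    return combined_loss
```

This would be used as, e.g., model.compile(loss=make_combined_loss([weighted_binary_crossentropy(pos_weight=1), end_time_error()], weights=[0.75, 0.25]), ...), and avoids the 'function' * 'float' TypeError because the multiplication happens on tensors, not on the functions.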

Upvotes: 1

Related Questions