Rajdeep Dutta

Reputation: 1036

'Reduction' parameter in tf.keras.losses

According to the docs, the reduction parameter can take three values: SUM_OVER_BATCH_SIZE, SUM, and NONE.

import tensorflow as tf

y_true = [[0., 2.], [0., 0.]]
y_pred = [[3., 1.], [2., 5.]]

mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5

mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 2.75

What I could infer about the calculation, after various trials, is this:

As a result, SUM_OVER_BATCH_SIZE is nothing but SUM / batch_size. Why, then, is it called SUM_OVER_BATCH_SIZE, when SUM actually adds up the losses over the entire batch while SUM_OVER_BATCH_SIZE calculates the average loss of the batch?

Is my assumption regarding the workings of SUM_OVER_BATCH_SIZE and SUM at all correct?

Upvotes: 7

Views: 9147

Answers (1)

CristoJV

Reputation: 500

Your assumption is correct as far as I understand.

If you check [keras/losses_utils.py][1] on GitHub (lines 260-269), you will see that it performs as expected. SUM adds up the per-sample losses over the batch dimension, and SUM_OVER_BATCH_SIZE divides that sum by the total number of losses (the batch size).

def reduce_weighted_loss(weighted_losses,
                         reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  if reduction == ReductionV2.NONE:
    loss = weighted_losses
  else:
    loss = tf.reduce_sum(weighted_losses)
    if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
      loss = _safe_mean(loss, _num_elements(weighted_losses))
  return loss
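
To make the three branches concrete, here is a minimal plain-Python sketch of the same logic. It is not the Keras implementation: the `Reduction` enum and `reduce_losses` function are hypothetical stand-ins written for illustration, and returning 0.0 for an empty batch is an assumption about what a "safe" mean would do.

```python
from enum import Enum


class Reduction(Enum):
    # Hypothetical stand-in for tf.keras.losses.Reduction (illustration only).
    NONE = "none"
    SUM = "sum"
    SUM_OVER_BATCH_SIZE = "sum_over_batch_size"


def reduce_losses(losses, reduction=Reduction.SUM_OVER_BATCH_SIZE):
    """Reduce a flat list of per-sample losses, mirroring the branches in the excerpt."""
    if reduction is Reduction.NONE:
        # No reduction: hand back one loss per sample.
        return list(losses)
    total = sum(losses)
    if reduction is Reduction.SUM_OVER_BATCH_SIZE:
        # Divide the sum by the number of per-sample losses.
        # Assumption: an empty batch yields 0.0 instead of a division error.
        return total / len(losses) if losses else 0.0
    return total


# Per-sample MAE values for the two-sample batch in the question.
per_sample = [2.0, 3.5]
print(reduce_losses(per_sample, Reduction.NONE))                 # [2.0, 3.5]
print(reduce_losses(per_sample, Reduction.SUM))                  # 5.5
print(reduce_losses(per_sample, Reduction.SUM_OVER_BATCH_SIZE))  # 2.75
```

Read this way, the name SUM_OVER_BATCH_SIZE describes the formula itself: the sum, divided by ("over") the batch size.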

You can easily check this with your previous example by adding one more pair of outputs that has zero loss.

y_true = [[0., 2.], [0., 0.], [1., 1.]]
y_pred = [[3., 1.], [2., 5.], [1., 1.]]

mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5

mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 1.8333
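
The same check can be reproduced by hand without TensorFlow. This is a rough sketch; `per_sample_mae` is a hypothetical helper that averages the absolute errors over the last axis of each sample, which is how MAE produces one loss per sample before any reduction is applied.

```python
def per_sample_mae(y_true, y_pred):
    """Mean absolute error of each sample, averaged over its last axis."""
    return [
        sum(abs(t - p) for t, p in zip(t_row, p_row)) / len(t_row)
        for t_row, p_row in zip(y_true, y_pred)
    ]


y_true = [[0., 2.], [0., 0.], [1., 1.]]
y_pred = [[3., 1.], [2., 5.], [1., 1.]]

losses = per_sample_mae(y_true, y_pred)
print(losses)                     # [2.0, 3.5, 0.0]
print(sum(losses))                # 5.5 (SUM: unchanged by the zero-loss sample)
print(sum(losses) / len(losses))  # 1.8333... (SUM_OVER_BATCH_SIZE: 5.5 / 3)
```

The extra sample contributes nothing to the sum but increases the divisor from 2 to 3, which is why only the default reduction changes.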

So, your assumption is correct.

[1]: https://github.com/keras-team/keras/blob/v2.7.0/keras/utils/losses_utils.py#L25-L84

Upvotes: 5
