just_another_beginner

Reputation: 159

Sparse Categorical CrossEntropy causing NAN loss

I've been trying to implement a few custom losses, so I thought I'd start by implementing sparse categorical cross-entropy (SCE) loss without using the built-in TF object. Here's the function I wrote for it.

def custom_loss(y_true, y_pred):
    print(y_true, y_pred)
    return tf.cast(tf.math.multiply(tf.experimental.numpy.log2(y_pred[y_true[0]]), -1), dtype=tf.float32)

y_pred is the set of probabilities, and y_true is the index of the correct one. This setup should work according to everything I've read, but it returns a NaN loss.

I checked whether there's a problem with the training loop, but it works perfectly with the built-in losses.

Could someone tell me what the problem is with this code?

Upvotes: 1

Views: 1413

Answers (1)

user11989081

Reputation: 8654

You can replicate the SparseCategoricalCrossentropy() loss function as follows:

import tensorflow as tf

def sparse_categorical_crossentropy(y_true, y_pred, clip=True):

    y_true = tf.convert_to_tensor(y_true, dtype=tf.int32)
    y_pred = tf.convert_to_tensor(y_pred, dtype=tf.float32)

    # Turn the integer class indices into a one-hot mask with the
    # same shape as the predictions.
    y_true = tf.one_hot(y_true, depth=y_pred.shape[1])

    if clip:
        # Keep the probabilities away from 0 and 1 so that log() stays finite.
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)

    # Select the predicted probability of the true class for each sample
    # and average the negative log-likelihoods over the batch.
    return -tf.reduce_mean(tf.math.log(y_pred[y_true == 1]))

Note that the SparseCategoricalCrossentropy() loss function clips the predicted probabilities to the range [1e-7, 1 - 1e-7] to make sure that the loss values are always finite; see also this question.
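For reference, that 1e-7 matches the default Keras backend epsilon (the "fuzz factor"), which you can query yourself; a minimal check, assuming TF 2.x:

import tensorflow as tf

# Default fuzz factor used when clipping probabilities in the
# built-in cross-entropy losses.
print(tf.keras.backend.epsilon())  # 1e-07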

y_true = [1, 2]
y_pred = [[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]]

print(tf.keras.losses.SparseCategoricalCrossentropy()(y_true, y_pred).numpy())
print(sparse_categorical_crossentropy(y_true, y_pred, clip=True).numpy())
print(sparse_categorical_crossentropy(y_true, y_pred, clip=False).numpy())
# 1.1769392
# 1.1769392
# 1.1769392

y_true = [1, 2]
y_pred = [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]

print(tf.keras.losses.SparseCategoricalCrossentropy()(y_true, y_pred).numpy())
print(sparse_categorical_crossentropy(y_true, y_pred, clip=True).numpy())
print(sparse_categorical_crossentropy(y_true, y_pred, clip=False).numpy())
# 8.059048
# 8.059048
# inf
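As for the function in the question: y_pred[y_true[0]] indexes with only the first label of the batch (it selects a whole row of y_pred rather than the probability of the true class for each sample), and log2 of a probability that is exactly 0 is -inf, which quickly turns into a NaN loss during training. A minimal sketch of a corrected version, assuming y_true holds one integer class index per sample and y_pred is a (batch, num_classes) probability matrix (note that the built-in loss uses the natural logarithm, not log2):

import tensorflow as tf

def custom_loss(y_true, y_pred):
    # One class index per sample, shape (batch,).
    y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
    # Pick the predicted probability of the true class for each sample.
    probs = tf.gather(y_pred, y_true, axis=1, batch_dims=1)
    # Clip so that log() never sees an exact 0 or 1.
    probs = tf.clip_by_value(probs, 1e-7, 1 - 1e-7)
    return -tf.reduce_mean(tf.math.log(probs))

With the first example above (y_true = [1, 2], y_pred = [[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]]) this sketch also returns 1.1769392.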

Upvotes: 4
