Joseph Adam

How to define a weighted loss function for TF2.0+ keras CNN for image classification?

I would like to use tf.nn.weighted_cross_entropy_with_logits to deal with class imbalance, but I am not sure how to do it. Class 0 has 10K images, while class 1 has only 500. Here is my code:

import tensorflow as tf

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), input_shape=(dim, dim, 3), activation='relu'),
    ....
    tf.keras.layers.Dense(2, activation='softmax')
])

model.compile(optimizer="nadam",
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])


class_weight = {0: 1.,
                1: 20.}

# val_ds must be passed as validation_data: a second positional
# argument would be interpreted as y, which is not allowed with a dataset
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epc,
    verbose=1,
    class_weight=class_weight)
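
The {0: 1., 1: 20.} dictionary above is simply the ratio of the class sizes (10,000 / 500 = 20). A minimal sketch of deriving such weights from the counts, using a hypothetical helper `inverse_frequency_weights` (not part of Keras), assuming the majority class should keep weight 1.0:

```python
def inverse_frequency_weights(counts):
    """Hypothetical helper: weight each class by (largest count) / (class count).

    The majority class gets weight 1.0 and every other class is scaled
    up in proportion to how rare it is.
    """
    largest = max(counts.values())
    return {label: largest / n for label, n in counts.items()}


# Class sizes from the question: 10K images of class 0, 500 of class 1
class_counts = {0: 10_000, 1: 500}
class_weight = inverse_frequency_weights(class_counts)
print(class_weight)  # {0: 1.0, 1: 20.0}
```

The resulting dictionary can be passed to model.fit(..., class_weight=class_weight) exactly like the hand-written one.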


Marco Cerliani

You can simply wrap tf.nn.weighted_cross_entropy_with_logits inside a custom loss function.

Remember also that tf.nn.weighted_cross_entropy_with_logits expects logits (it applies a sigmoid internally), so your network must output raw logits rather than probabilities: remove the softmax activation from the last layer.
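
For reference, the TensorFlow documentation defines this op as labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits)), i.e. pos_weight multiplies only the positive-label term. A plain-Python sketch of that formula for a single element (the helper names are ours, not TensorFlow's):

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_sigmoid_ce(label, logit, pos_weight):
    """Per-element weighted sigmoid cross-entropy.

    Uses the identities -log(sigmoid(x)) == softplus(-x)
    and -log(1 - sigmoid(x)) == softplus(x).
    """
    return label * pos_weight * softplus(-logit) + (1 - label) * softplus(logit)


# A logit of 0 means sigmoid = 0.5, so the unweighted loss is log(2)
print(weighted_sigmoid_ce(1.0, 0.0, pos_weight=1.0))  # ~0.693
print(weighted_sigmoid_ce(1.0, 0.0, pos_weight=2.0))  # ~1.386, positive term doubled
print(weighted_sigmoid_ce(0.0, 0.0, pos_weight=2.0))  # ~0.693, negatives unaffected
```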

Here is a dummy example:

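import numpy as np
import tensorflow as tf

# Dummy data: 10 RGB images of 32x32 with one-hot labels for 2 classes
X = np.random.uniform(0, 1, (10, 32, 32, 3))
y = np.random.randint(0, 2, (10,))
y = tf.keras.utils.to_categorical(y, num_classes=2)  # always 2 columns

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), input_shape=(32, 32, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2) ### must be logits (remove softmax)
])

def my_loss(weight):
    # Keras calls the inner function as loss(y_true, y_pred),
    # so the weight is captured by the closure
    def weighted_cross_entropy_with_logits(labels, logits):
        labels = tf.cast(labels, logits.dtype)  # the op needs matching dtypes
        loss = tf.nn.weighted_cross_entropy_with_logits(
            labels, logits, weight
        )
        return loss
    return weighted_cross_entropy_with_logits

model.compile(optimizer="nadam",
              loss=my_loss(weight=0.8),
              metrics=['accuracy'])
model.fit(X, y, epochs=3)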
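
One caveat about a scalar weight such as 0.8: pos_weight scales the loss on every positive label, and with one-hot labels each image has exactly one positive entry, so a scalar weights the positive term of class 0 and class 1 alike. To favour the rare class, the weight can be given per class instead. A plain-Python sketch comparing the two (hypothetical helper names; the per-image loss is assumed to be the average over the two outputs):

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sample_loss(one_hot, logits, pos_weight):
    """Weighted sigmoid cross-entropy averaged over the output units.

    pos_weight is a scalar or one weight per class, mirroring how a
    per-class pos_weight broadcasts against [batch, num_classes] labels.
    """
    if not isinstance(pos_weight, (list, tuple)):
        pos_weight = [pos_weight] * len(one_hot)
    terms = [
        z * q * softplus(-x) + (1 - z) * softplus(x)
        for z, x, q in zip(one_hot, logits, pos_weight)
    ]
    return sum(terms) / len(terms)


logits = [0.0, 0.0]  # an untrained network: both outputs at 0
for weight in (0.8, [1.0, 20.0]):
    loss_0 = sample_loss([1.0, 0.0], logits, weight)  # a class-0 image
    loss_1 = sample_loss([0.0, 1.0], logits, weight)  # a class-1 image
    print(weight, round(loss_0, 3), round(loss_1, 3))
# 0.8 0.624 0.624          -> both classes weighted the same
# [1.0, 20.0] 0.693 7.278  -> class-1 images weigh much more
```

In the Keras code, the per-class version would be my_loss(weight=tf.constant([1.0, 20.0])). The op computes 1 + (pos_weight - 1) * labels, so a length-2 tensor broadcasts against [batch, 2] labels, whereas a plain Python list would fail on the subtraction.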

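At inference time the model outputs logits, so you convert them into probabilities yourself, for example:

tf.nn.softmax(model.predict(X))

Since this loss is sigmoid-based, tf.nn.sigmoid(model.predict(X)) gives independent per-class probabilities instead.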
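
To make the difference between the two conversions concrete, a plain-Python sketch with a hypothetical pair of logits: sigmoid maps each output on its own, softmax normalizes them jointly. Both are monotonic in the logits, so they agree on the predicted class.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.5, -0.5]  # hypothetical model.predict output for one image
per_unit = [sigmoid(x) for x in logits]  # independent per-class probabilities
normalized = softmax(logits)             # rescaled to sum to 1

print([round(p, 3) for p in per_unit])    # [0.818, 0.378]
print([round(p, 3) for p in normalized])  # [0.881, 0.119]
print(per_unit.index(max(per_unit)), normalized.index(max(normalized)))  # 0 0
```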
