Reputation: 451
I am training a neural network for multilabel classification with a large number of classes (1000), which means more than one output can be active for every input. On average, two classes are active per output frame. When I train with a cross-entropy loss, the network resorts to outputting only zeros, because that output gives the lowest loss when 99.8% of my labels are zeros. Any suggestions on how I can push the network to give more weight to the positive classes?
Upvotes: 13
Views: 4125
Reputation: 1098
Many thanks to tobigue for the great solution.
The TensorFlow and Keras APIs have changed since that answer, so an updated version of weighted_binary_crossentropy for TensorFlow 2.7.0 is below.
import tensorflow as tf

POS_WEIGHT = 10  # multiplier for positive targets, needs to be tuned

def weighted_binary_crossentropy(target, output):
    """
    Weighted binary crossentropy between an output tensor
    and a target tensor. POS_WEIGHT is used as a multiplier
    for the positive targets.

    Combination of the following functions:
    * keras.losses.binary_crossentropy
    * keras.backend.tensorflow_backend.binary_crossentropy
    * tf.nn.weighted_cross_entropy_with_logits
    """
    # transform back to logits
    _epsilon = tf.convert_to_tensor(tf.keras.backend.epsilon(), output.dtype.base_dtype)
    output = tf.clip_by_value(output, _epsilon, 1 - _epsilon)
    output = tf.math.log(output / (1 - output))
    # compute weighted loss
    loss = tf.nn.weighted_cross_entropy_with_logits(labels=target, logits=output, pos_weight=POS_WEIGHT)
    return tf.reduce_mean(loss, axis=-1)
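To see what the "transform back to logits" step is doing, here is a minimal pure-Python sketch (no TensorFlow needed). It assumes the model's last layer is a sigmoid, so the function receives probabilities, and it assumes an epsilon of 1e-7 to mirror the Keras default. The helper names prob_to_logit and sigmoid are made up for this illustration and are not part of any library.

```python
import math

EPSILON = 1e-7  # assumption: same order as the Keras default fuzz factor


def sigmoid(x):
    """Logistic function, mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def prob_to_logit(p, eps=EPSILON):
    """Clip p into [eps, 1 - eps], then invert the sigmoid."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


# Round trip: sigmoid(prob_to_logit(p)) should give back p.
probabilities = (0.1, 0.5, 0.9)
recovered = [sigmoid(prob_to_logit(p)) for p in probabilities]

# Without clipping, p = 0.0 would mean log(0); with clipping it stays finite.
extreme = prob_to_logit(0.0)
```

The clipping is what keeps the loss finite when the network saturates at exactly 0 or 1. If your model already outputs logits (no sigmoid on the last layer), this step should be skipped and the logits passed to tf.nn.weighted_cross_entropy_with_logits directly.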
Upvotes: 0
Reputation: 3607
TensorFlow has a loss function, weighted_cross_entropy_with_logits, which can be used to give more weight to the 1's. So it should be applicable to a sparse multi-label classification setting like yours.
From the documentation:
This is like sigmoid_cross_entropy_with_logits() except that pos_weight, allows one to trade off recall and precision by up- or down-weighting the cost of a positive error relative to a negative error.
The argument pos_weight is used as a multiplier for the positive targets
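As a rough illustration of what that multiplier does, here is a pure-Python sketch of the per-element loss as I read it from the documentation: pos_weight * y * -log(sigmoid(x)) + (1 - y) * -log(1 - sigmoid(x)). The function name weighted_bce and the example logit of -4.0 are my own choices for the illustration. The real TensorFlow op uses a numerically more stable formulation, so treat this as a sketch of the math only.

```python
import math


def sigmoid(x):
    """Logistic function, mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def weighted_bce(label, logit, pos_weight):
    """Per-element weighted cross entropy (naive form, illustration only)."""
    p = sigmoid(logit)
    positive_term = pos_weight * label * -math.log(p)
    negative_term = (1.0 - label) * -math.log(1.0 - p)
    return positive_term + negative_term


# A confident "zero" prediction (logit = -4.0, an arbitrary example value).
miss_plain = weighted_bce(1.0, -4.0, pos_weight=1.0)      # missed positive, unweighted
miss_weighted = weighted_bce(1.0, -4.0, pos_weight=10.0)  # missed positive, weighted
negative = weighted_bce(0.0, -4.0, pos_weight=10.0)       # correct negative, unaffected

# One frame as in the question: 1000 classes, 2 active, network outputs all "zeros".
labels = [1.0, 1.0] + [0.0] * 998
mean_plain = sum(weighted_bce(y, -4.0, 1.0) for y in labels) / len(labels)
mean_weighted = sum(weighted_bce(y, -4.0, 10.0) for y in labels) / len(labels)
```

Only the terms for positive labels are scaled, so an all-zeros output becomes more expensive while correct negatives cost the same as before.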
If you use the tensorflow backend in Keras, you can use the loss function like this (Keras 2.1.1):
import tensorflow as tf
import keras.backend.tensorflow_backend as tfb

POS_WEIGHT = 10  # multiplier for positive targets, needs to be tuned

def weighted_binary_crossentropy(target, output):
    """
    Weighted binary crossentropy between an output tensor
    and a target tensor. POS_WEIGHT is used as a multiplier
    for the positive targets.

    Combination of the following functions:
    * keras.losses.binary_crossentropy
    * keras.backend.tensorflow_backend.binary_crossentropy
    * tf.nn.weighted_cross_entropy_with_logits
    """
    # transform back to logits
    _epsilon = tfb._to_tensor(tfb.epsilon(), output.dtype.base_dtype)
    output = tf.clip_by_value(output, _epsilon, 1 - _epsilon)
    output = tf.log(output / (1 - output))
    # compute weighted loss
    loss = tf.nn.weighted_cross_entropy_with_logits(targets=target,
                                                    logits=output,
                                                    pos_weight=POS_WEIGHT)
    return tf.reduce_mean(loss, axis=-1)
Then in your model:
model.compile(loss=weighted_binary_crossentropy, ...)
I have not yet found many resources that report values of pos_weight that work well in relation to the number of classes, the average number of active classes, etc.
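One starting point people often use (a heuristic, not something the TensorFlow documentation prescribes for this case) is the ratio of negative to positive labels in the training set. A minimal sketch with toy data shaped like the question's setup follows. The helper estimate_pos_weight and the active class indices are invented for the example.

```python
def estimate_pos_weight(label_rows):
    """Ratio of negative to positive labels over all rows (heuristic)."""
    positives = sum(sum(row) for row in label_rows)
    total = sum(len(row) for row in label_rows)
    if positives == 0:
        raise ValueError("no positive labels, ratio is undefined")
    return (total - positives) / positives


# Toy data: 3 frames, 1000 classes, 2 active classes per frame.
NUM_CLASSES = 1000
active_per_frame = [(3, 17), (250, 999), (0, 512)]  # hypothetical class indices

rows = []
for active in active_per_frame:
    row = [0] * NUM_CLASSES
    for idx in active:
        row[idx] = 1
    rows.append(row)

pos_weight = estimate_pos_weight(rows)  # 2994 negatives / 6 positives
```

A ratio this large fully balances positives against negatives and may well push the network toward predicting too many positives, so I would treat it as an upper bound to tune down from while watching precision and recall on a validation set.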
Upvotes: 9