Ines

Reputation: 19

weighted loss function for multilabel classification

I am working on a multilabel classification problem for images. I have 5 classes, and I use a sigmoid activation in the last layer of the classifier. The data is imbalanced (a consequence of the multilabel setup), and I thought I could use:

tf.nn.weighted_cross_entropy_with_logits(labels, logits, pos_weight, name=None)

However, I don't know how to get the logits from my model. I also think I shouldn't use a sigmoid in the last layer, since this loss function applies a sigmoid to the logits itself.

Upvotes: 1

Views: 1404

Answers (1)

ClaudiaR

Reputation: 3434

First of all, I suggest you have a look at the TensorFlow tutorial on classification with an imbalanced dataset. Keep in mind that this tutorial covers binary classification and uses a sigmoid as the last dense layer's activation function. If your classes are mutually exclusive (multi-class, one label per sample), you should use a softmax activation instead: the softmax normalizes a set of N real numbers into a probability distribution that sums to 1, and for K = 2 the softmax and the sigmoid give the same probabilities. If several labels can be active at once (true multi-label), a per-class sigmoid is the right choice.
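The K = 2 equivalence is easy to check numerically: a softmax over the two logits [0, z] assigns the second class the same probability as sigmoid(z). A quick sketch in plain NumPy:

```python
import numpy as np

def softmax(x):
    # numerically stable softmax over a 1-D array
    e = np.exp(x - np.max(x))
    return e / e.sum()

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z = 1.5
p_softmax = softmax(np.array([0.0, z]))[1]  # probability of the "positive" class
p_sigmoid = sigmoid(z)
print(p_softmax, p_sigmoid)  # the two values match
```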

I don't know your model, but you could create something like this (following the tutorial):

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),  # e.g. 28x28 grayscale images
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=None)      # no activation: outputs raw logits
])

To obtain the predictions you could do:

predictions = model(x_train[:1]).numpy()  # obtains the prediction logits
tf.nn.softmax(predictions).numpy()  # converts the logits to probabilities

In order to train you can define the following loss, compile the model, and train:

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)

Now, since you have an imbalanced dataset, you can add weights. If you look at the documentation of SparseCategoricalCrossentropy, you can see that its __call__ method takes an optional sample_weight parameter:

Optional sample_weight acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If sample_weight is a tensor of size [batch_size], then the total loss for each sample of the batch is rescaled by the corresponding element in the sample_weight vector.
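For example (a sketch with made-up logits, labels, and weight values), you can pass a per-sample weight vector when calling the loss, up-weighting samples that belong to rare classes:

```python
import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Made-up batch of 2 samples and 3 classes (raw logits, no softmax applied).
logits = tf.constant([[2.0, 0.5, 0.1],
                      [0.2, 1.5, 0.3]])
labels = tf.constant([0, 1])

base_loss = loss_fn(labels, logits)         # unweighted mean loss over the batch
weights = tf.constant([1.0, 3.0])           # up-weight the second (rarer-class) sample
weighted_loss = loss_fn(labels, logits, sample_weight=weights)
```

With the default reduction, each sample's loss is scaled by its weight before averaging over the batch.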

I suggest you have a look at this answer if you have doubts about how to proceed. I think it answers exactly what you want to achieve.

I also find that this tutorial explains the multi-label classification problem pretty well.
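And since your question mentions tf.nn.weighted_cross_entropy_with_logits: for the sigmoid/multi-label setup, that function expects raw logits (so drop the sigmoid from the last layer, as you suspected) and multi-hot labels, and pos_weight up-weights positive examples class by class. A minimal sketch with made-up numbers for your 5 classes:

```python
import tensorflow as tf

# Raw model outputs (logits) for one sample, 5 classes -- last layer has activation=None.
logits = tf.constant([[1.2, -0.8, 0.3, -2.0, 0.5]])
# Multi-hot targets: classes 0 and 2 are present in this sample.
labels = tf.constant([[1.0, 0.0, 1.0, 0.0, 0.0]])
# One positive weight per class, e.g. derived from class frequencies (values made up).
pos_weight = tf.constant([2.0, 5.0, 1.5, 8.0, 3.0])

loss = tf.nn.weighted_cross_entropy_with_logits(labels=labels, logits=logits,
                                                pos_weight=pos_weight)
total_loss = tf.reduce_mean(loss)  # one loss term per class, averaged
```

With pos_weight set to all ones this reduces to plain sigmoid cross-entropy.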

Upvotes: 1
