Reputation: 433
I am training a multi-label classification model for detecting attributes of clothes. I am using transfer learning in Keras, retraining the last few layers of the VGG-19 model.
There are 1000 attributes in total, and about 99% of the label values are 0. Metrics such as accuracy, precision, and recall all fail, because the model can predict all zeroes and still achieve a very high score. As loss functions, binary cross-entropy, Hamming loss, etc. have not worked either.
I am using the DeepFashion dataset.
So, which metrics and loss functions can I use to measure my model correctly?
Upvotes: 28
Views: 56181
Reputation: 645
Categorical Cross-Entropy loss or Softmax Loss is a Softmax activation plus a Cross-Entropy loss. If we use this loss, we will train a CNN to output a probability over the C classes for each image. It is used for multi-class classification.
What you want is multi-label classification, so you will use Binary Cross-Entropy Loss or Sigmoid Cross-Entropy loss. It is a Sigmoid activation plus a Cross-Entropy loss. Unlike Softmax loss it is independent for each vector component (class), meaning that the loss computed for every CNN output vector component is not affected by other component values. That’s why it is used for multi-label classification, where the insight of an element belonging to a certain class should not influence the decision for another class.
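The independence property described above can be checked numerically. The following is a minimal sketch using NumPy rather than Keras; the helper names `softmax` and `sigmoid` and the example logits are illustrative assumptions, not part of the original answer. It changes a single logit and compares how the other outputs react under each activation.

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability; outputs sum to 1.
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()

def sigmoid(z):
    # Applied element-wise; each output depends only on its own logit.
    return 1.0 / (1.0 + np.exp(-z))

logits = np.array([2.0, 1.0, -1.0])
bumped = logits.copy()
bumped[0] += 3.0  # change only the first class's logit

soft_before, soft_after = softmax(logits), softmax(bumped)
sig_before, sig_after = sigmoid(logits), sigmoid(bumped)

# Softmax: probabilities of classes 1 and 2 shrink although their logits did not change.
print("softmax other classes:", soft_before[1:], "->", soft_after[1:])
# Sigmoid: outputs of classes 1 and 2 are untouched.
print("sigmoid other classes:", sig_before[1:], "->", sig_after[1:])
```

With softmax the classes compete for a fixed probability mass, which is why it suits single-label problems; with sigmoid several labels can be close to 1 at the same time.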
To handle the class imbalance, you can use a weighted Sigmoid Cross-Entropy loss: errors on positive labels are penalized more heavily, with a weight based on the number (or ratio) of positive examples.
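One way to derive such weights is from the label frequencies in the training set. The sketch below is a NumPy illustration under stated assumptions: the negatives-to-positives ratio, the clipping bound `max_weight`, and the function names are all hypothetical choices, not something the answer prescribes.

```python
import numpy as np

def positive_class_weights(y_true, max_weight=100.0):
    """Per-label weight = #negatives / #positives, clipped to [1, max_weight]."""
    y_true = np.asarray(y_true, dtype=np.float64)
    n_pos = y_true.sum(axis=0)
    n_neg = y_true.shape[0] - n_pos
    # Guard against division by zero for labels that never occur.
    weights = n_neg / np.maximum(n_pos, 1.0)
    return np.clip(weights, 1.0, max_weight)

def weighted_bce(y_true, y_prob, pos_weight, eps=1e-7):
    """Mean binary cross-entropy where the positive term is scaled by pos_weight."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_prob = np.clip(np.asarray(y_prob, dtype=np.float64), eps, 1.0 - eps)
    per_label = -(pos_weight * y_true * np.log(y_prob)
                  + (1.0 - y_true) * np.log(1.0 - y_prob))
    return per_label.mean()

# Toy label matrix: 4 samples, 3 attributes, mostly zeros.
y = np.array([[1, 0, 0],
              [0, 0, 0],
              [0, 1, 0],
              [0, 0, 0]])
w = positive_class_weights(y)

# A model that predicts "almost zero" everywhere.
near_zero = np.full(y.shape, 0.01)
plain_loss = weighted_bce(y, near_zero, np.ones(3))
weighted_loss = weighted_bce(y, near_zero, w)
print(w, plain_loss, weighted_loss)
```

The weighted loss makes the all-zeros prediction more expensive than the unweighted one, which is the effect the answer is after.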
Upvotes: 45
Reputation: 382
You can refer to this GitHub repository. It covers binary, multi-class, and multi-label classification, and has options either to force the model to learn outputs close to 0 and 1 or simply to learn probabilities.
Steve
Upvotes: 1
Reputation: 339
Actually, you should use tf.nn.weighted_cross_entropy_with_logits. It works for multi-label classification and also has a pos_weight argument, which lets the loss pay more attention to the positive classes, as you would expect.
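For intuition, the quantity this function computes can be reproduced in plain NumPy. The sketch below follows the formula given in the TensorFlow documentation, `labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits))`, rewritten in a numerically stable form. The function name `weighted_sigmoid_ce_reference` is hypothetical and this is not the library's implementation.

```python
import numpy as np

def weighted_sigmoid_ce_reference(labels, logits, pos_weight):
    """Element-wise weighted sigmoid cross-entropy computed from raw logits."""
    labels = np.asarray(labels, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    # Weight is pos_weight where the label is 1, and 1 where the label is 0.
    log_weight = 1.0 + (pos_weight - 1.0) * labels
    # log(1 + exp(-x)) computed stably as log1p(exp(-|x|)) + max(-x, 0).
    softplus_neg = np.log1p(np.exp(-np.abs(logits))) + np.maximum(-logits, 0.0)
    return (1.0 - labels) * logits + log_weight * softplus_neg

labels = np.array([1.0, 0.0, 1.0])
logits = np.array([-2.0, -2.0, 3.0])

unweighted = weighted_sigmoid_ce_reference(labels, logits, pos_weight=1.0)
weighted = weighted_sigmoid_ce_reference(labels, logits, pos_weight=10.0)
print(unweighted)
print(weighted)
```

Only the entries whose label is 1 are scaled by `pos_weight`; the loss for negative labels is unchanged. A value of `pos_weight` above 1 therefore trades precision for recall.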
Upvotes: 9
Reputation: 836
Whether the problem is multi-class or binary determines the number of output units, i.e. the number of neurons in the final layer. Whether it is multi-label or single-label determines which activation function for the final layer and which loss function you should use. For single-label, the standard choice is Softmax with categorical cross-entropy; for multi-label, switch to Sigmoid activations with binary cross-entropy.
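Categorical Cross-Entropy (standard form):
$$L(y, \hat{y}) = -\sum_{j=1}^{C} y_j \log \hat{y}_j, \qquad J = \frac{1}{m} \sum_{i=1}^{m} L\big(y^{(i)}, \hat{y}^{(i)}\big)$$
Binary Cross-Entropy (standard form):
$$L(y, \hat{y}) = -\sum_{j=1}^{C} \big[ y_j \log \hat{y}_j + (1 - y_j) \log(1 - \hat{y}_j) \big], \qquad J = \frac{1}{m} \sum_{i=1}^{m} L\big(y^{(i)}, \hat{y}^{(i)}\big)$$
C is the number of classes, and m is the number of examples in the current mini-batch. L is the loss function and J is the cost function; y is the target vector and \hat{y} the predicted vector for one example. You can also see here.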
In the loss function, you are iterating over different classes. In the cost function, you are iterating over the examples in the current mini-batch.
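The loss/cost distinction can be made concrete with a short NumPy sketch. It assumes the standard binary cross-entropy form, summed over the C classes for the loss and averaged over the m examples for the cost; the function names are hypothetical.

```python
import numpy as np

def bce_loss_per_example(y, y_hat, eps=1e-7):
    """Loss L: sum over the C classes (axis 1), one value per example."""
    y = np.asarray(y, dtype=np.float64)
    y_hat = np.clip(np.asarray(y_hat, dtype=np.float64), eps, 1.0 - eps)
    return -np.sum(y * np.log(y_hat) + (1.0 - y) * np.log(1.0 - y_hat), axis=1)

def cost(losses):
    """Cost J: average of the per-example losses over the m examples."""
    return np.mean(losses)

# Mini-batch with m = 2 examples and C = 3 classes.
y = np.array([[1, 0, 0],
              [0, 1, 1]])
y_hat = np.full((2, 3), 0.5)  # a maximally uncertain prediction

losses = bce_loss_per_example(y, y_hat)  # shape (m,)
J = cost(losses)                         # scalar
print(losses, J)
```

Some libraries average over the classes instead of summing, which rescales the value by 1/C without changing where the minimum lies.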
Upvotes: 0
Reputation: 82
I have been in a similar situation.
You can use a softmax activation function in the output layer with categorical_crossentropy. To check other metrics such as precision, recall and F1 score, you can use the sklearn library as follows:
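import numpy as np
from sklearn.metrics import classification_report

# Predicted class probabilities, shape (n_samples, n_classes)
y_pred = model.predict(x_test, batch_size=64, verbose=1)
# Index of the most probable class per sample (y_test must hold integer class labels)
y_pred_bool = np.argmax(y_pred, axis=1)
print(classification_report(y_test, y_pred_bool))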
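Note that argmax keeps exactly one label per sample. For the multi-label setting in the question, a thresholded variant may be closer to what is needed. The following is a minimal NumPy sketch, assuming sigmoid outputs and a hypothetical threshold of 0.5; the function name and the toy data are illustrative only. It computes micro-averaged precision, recall and F1 over all (sample, label) pairs and contrasts them with element-wise accuracy.

```python
import numpy as np

def micro_precision_recall_f1(y_true, y_prob, threshold=0.5):
    """Micro-averaged precision, recall and F1 over all (sample, label) pairs."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_prob) >= threshold
    tp = np.sum(y_true & y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    precision = float(tp / (tp + fp)) if (tp + fp) else 0.0
    recall = float(tp / (tp + fn)) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2.0 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

y_true = np.array([[1, 0, 0, 0],
                   [0, 0, 1, 0]])

# A model that predicts zero for every label.
all_zero = np.zeros(y_true.shape)
elementwise_accuracy = float(np.mean((all_zero >= 0.5) == y_true.astype(bool)))
zero_scores = micro_precision_recall_f1(y_true, all_zero)

# A model that finds both positives and makes one false alarm.
y_prob = np.array([[0.9, 0.1, 0.2, 0.1],
                   [0.2, 0.7, 0.8, 0.1]])
scores = micro_precision_recall_f1(y_true, y_prob)
print(elementwise_accuracy, zero_scores, scores)
```

Element-wise accuracy rewards the all-zeros model for the many true negatives, while precision, recall and F1 ignore true negatives and so expose it.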
As for the training stage, as far as I know there is the accuracy metric, which you can enable as follows:
model.compile(loss='categorical_crossentropy', metrics=['acc'], optimizer='adam')
If it helps, you can plot the training history for the loss and accuracy of your training stage using matplotlib as follows:
import matplotlib.pyplot as plt

# `checkpoint` is a callback defined elsewhere (e.g. a keras.callbacks.ModelCheckpoint instance)
hist = model.fit(x_train, y_train, batch_size=24, epochs=1000, verbose=2,
                 callbacks=[checkpoint],
                 validation_data=(x_valid, y_valid))
# Plot training & validation accuracy values
plt.plot(hist.history['acc'])
plt.plot(hist.history['val_acc'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
# Plot training & validation loss values
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
Upvotes: -1