pgol

Reputation: 55

How to calculate precision, recall in multiclass classification problem after each epoch during training?

I am using TensorFlow 1.15.0 and Keras 2.3.1. I'm trying to calculate precision and recall for a six-class classification problem after each epoch, on both my training data and my validation data, while training is running. I can use classification_report, but it only works after training has completed.

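import numpy as np
from sklearn.metrics import classification_report

# Predicted class probabilities, shape (n_samples, 6)
y_pred = final.predict(X_test)
# Convert one-hot labels and predicted probabilities to class indices
y_indx = np.argmax(y_test_new, axis=1)
pred_indx = np.argmax(y_pred, axis=1)
print(classification_report(y_indx, pred_indx))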

The result for my ResNet154 network is below; my dataset is balanced.

              precision    recall  f1-score   support

           0       0.00      0.00      0.00    172482
           1       0.00      0.00      0.00    172482
           2       0.00      0.00      0.00    172482
           3       0.00      0.00      0.00    172482
           4       0.00      0.00      0.00    172482
           5       0.17      1.00      0.29    172482

    accuracy                           0.17   1034892
   macro avg       0.03      0.17      0.05   1034892
weighted avg       0.03      0.17      0.05   1034892

I just want to check the precision, recall, and F1-score on my training data after each epoch by using callbacks, so that I can tell whether or not the network is overfitting.

Upvotes: 3

Views: 2542

Answers (1)

Timbus Calin

Reputation: 15063

You need to define a custom callback to do this.

One solution to your problem is available in the following article: https://medium.com/@thongonary/how-to-compute-f1-score-for-each-epoch-in-keras-a1acd17715a2.

The article above shows how to calculate the metrics you want at the end of each epoch.
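A minimal sketch of such a callback, assuming tf.keras and scikit-learn >= 0.22 (needed for the `zero_division` argument). The class name `PerClassMetrics` and its arguments are hypothetical and not taken from the article. At the end of every epoch it predicts on the data you pass in and computes per-class precision, recall and F1. With standalone Keras 2.3.1, subclass `keras.callbacks.Callback` instead, so that the callback comes from the same library as the model.

```python
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_recall_fscore_support


class PerClassMetrics(tf.keras.callbacks.Callback):
    """Computes per-class precision, recall and F1 on (x, y_onehot) after every epoch."""

    def __init__(self, x, y_onehot, prefix="val", batch_size=256):
        super().__init__()
        self.x = x
        self.y_true = np.argmax(y_onehot, axis=1)    # one-hot -> class index
        self.labels = np.arange(y_onehot.shape[1])   # include classes that are never predicted
        self.prefix = prefix
        self.batch_size = batch_size
        self.records = []                            # one dict of per-class arrays per epoch

    def on_epoch_end(self, epoch, logs=None):
        probs = self.model.predict(self.x, batch_size=self.batch_size, verbose=0)
        y_pred = np.argmax(probs, axis=1)
        precision, recall, f1, _ = precision_recall_fscore_support(
            self.y_true, y_pred, labels=self.labels, zero_division=0)
        self.records.append({"precision": precision, "recall": recall, "f1": f1})
        print(f"epoch {epoch + 1} [{self.prefix}] "
              f"macro precision={precision.mean():.3f} "
              f"recall={recall.mean():.3f} f1={f1.mean():.3f}")


# Hypothetical usage (X_train, y_train_new, X_val, y_val_new are placeholders):
# callbacks = [PerClassMetrics(X_train, y_train_new, prefix="train"),
#              PerClassMetrics(X_val, y_val_new, prefix="val")]
# final.fit(X_train, y_train_new, validation_data=(X_val, y_val_new),
#           epochs=10, callbacks=callbacks)
```

Passing one instance for the training data and one for the validation data lets you compare the two after each epoch. Predicting on the full training set (about 1M samples here) every epoch can be slow, so a random subset of it may be enough.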

Alternatively, you can define a custom callback that has access to your validation set; in on_epoch_end(), compute the number of TP, TN, FP, and FN for each class, from which you can calculate any metric you want.
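A minimal NumPy sketch of that counting step, assuming integer class labels (for example after `np.argmax` on the one-hot arrays, as in the question). The helper names `confusion_counts` and `precision_recall_f1` are hypothetical. Each class is treated one-vs-rest, and a metric whose denominator is 0 is reported as 0.0 (similar to scikit-learn's `zero_division=0`).

```python
import numpy as np


def confusion_counts(y_true, y_pred, num_classes):
    """Return per-class TP, FP, FN, TN (one-vs-rest) from integer class labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # cm[i, j] = number of samples whose true class is i and predicted class is j
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as class c, but the true class differs
    fn = cm.sum(axis=1) - tp  # true class c, but predicted as another class
    tn = cm.sum() - tp - fp - fn
    return tp, fp, fn, tn


def precision_recall_f1(tp, fp, fn):
    """Per-class precision, recall and F1; 0.0 wherever the denominator is 0."""
    tp, fp, fn = (np.asarray(a, dtype=float) for a in (tp, fp, fn))
    precision = np.divide(tp, tp + fp, out=np.zeros_like(tp), where=(tp + fp) > 0)
    recall = np.divide(tp, tp + fn, out=np.zeros_like(tp), where=(tp + fn) > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1
```

For example, a model that always predicts class 5 on a balanced six-class set gets precision 1/6, recall 1.0 and F1 2/7 for class 5, and 0 for every other class, which is the same pattern as the report in the question.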

You can also check this example (works on TensorFlow 2.x, >= 2.1): How to get other metrics in Tensorflow 2.0 (not only accuracy)?
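In TensorFlow 2.x, a related option (not necessarily the approach in the linked example) is to pass one built-in `tf.keras.metrics.Precision` and `tf.keras.metrics.Recall` per class to `compile()`. With `top_k=1` the highest-probability class counts as the prediction, and `class_id=c` restricts the metric to class `c`. The helper names and the toy model below are hypothetical.

```python
import tensorflow as tf


def per_class_precision_recall(num_classes):
    """One Precision and one Recall metric per class, based on the top-1 (argmax) prediction."""
    metrics = []
    for c in range(num_classes):
        # top_k=1: the class with the highest predicted probability is the prediction;
        # class_id=c: compute the metric for class c only (one-vs-rest).
        metrics.append(tf.keras.metrics.Precision(top_k=1, class_id=c, name=f"precision_{c}"))
        metrics.append(tf.keras.metrics.Recall(top_k=1, class_id=c, name=f"recall_{c}"))
    return metrics


def build_model(input_dim, num_classes):
    # Hypothetical toy model; replace it with your own network.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["accuracy"] + per_class_precision_recall(num_classes))
    return model
```

The training values (`precision_0`, ...) are accumulated over the batches of each epoch while the weights are still changing, whereas the `val_` values are computed on the validation data at the end of the epoch, so the two are not perfectly comparable.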

Upvotes: 1
