Nuria

Reputation: 43

Plot precision and recall with sklearn

I have created a classification model with a custom ML framework.

I have 3 classes: 1, 2, 3

Input sample:

# y_true, y_pred, and y_scores are lists

print(y_true[0], y_pred[0], y_scores[0])
print(y_true[1], y_pred[1], y_scores[1])
print(y_true[2], y_pred[2], y_scores[2])

1 1 0.6903580037019461
3 3 0.8805178752523366
1 2 0.32107199420078963

Using sklearn, I'm able to call metrics.classification_report:

metrics.classification_report(y_true, y_pred)

                         precision    recall  f1-score   support

                      1      0.521     0.950     0.673        400
                      2      0.000     0.000     0.000        290
                      3      0.885     0.742     0.807        310

               accuracy                          0.610       1000
              macro avg      0.468     0.564     0.493       1000
           weighted avg      0.482     0.610     0.519       1000

I want to generate a precision vs. recall visualization.

But calling metrics.precision_recall_curve gives me this error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-108-2ebb913a4e4b> in <module>()
----> 1 precision, recall, thresholds = metrics.precision_recall_curve(y_true, y_scores)

1 frames
/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    534     if not (y_type == "binary" or
    535             (y_type == "multiclass" and pos_label is not None)):
--> 536         raise ValueError("{0} format is not supported".format(y_type))
    537 
    538     check_consistent_length(y_true, y_score, sample_weight)

ValueError: multiclass format is not supported 

I found some examples:

But it's not clear to me how to binarize my arrays when I already have the results. I'm looking for pointers on how to simply plot it.

Upvotes: 1

Views: 1366

Answers (1)

Ben Reiniger

Reputation: 12582

precision_recall_curve has a parameter, pos_label, that sets which label counts as the "positive" class when counting TP/TN/FP/FN; every other class is treated as negative. So you can extract the probability column for that class and then generate the precision/recall points as:

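from sklearn.metrics import precision_recall_curve

y_pred = model.predict_proba(X)  # one column per class, ordered as model.classes_

index = 2  # or 0 or 1; maybe you want to loop?
label = model.classes_[index]  # see below
p, r, t = precision_recall_curve(y_true, y_pred[:, index], pos_label=label)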

The main obnoxiousness here is that you need to extract the column of y_pred by index, but pos_label expects the actual class label. You can connect those using model.classes_.
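A minimal sketch of the loop hinted at in the code comment above. make_classification and LogisticRegression are stand-ins (assumptions, not from the question) for the unspecified model and X, and the labels are shifted to 1, 2, 3 to match the question. enumerate(model.classes_) yields each column index together with the class label it belongs to, so the two cannot get out of sync:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve

# Stand-in data and model (assumptions, not from the question): 3 classes,
# labels shifted from 0..2 to 1..3 to match the question's labels.
X, y_true = make_classification(n_samples=300, n_features=6, n_informative=4,
                                n_classes=3, random_state=0)
y_true = y_true + 1

model = LogisticRegression(max_iter=1000).fit(X, y_true)
y_pred = model.predict_proba(X)  # column j holds the probability of model.classes_[j]

curves = {}
for index, label in enumerate(model.classes_):
    # Pick the column by position, but give pos_label the actual class label.
    p, r, t = precision_recall_curve(y_true, y_pred[:, index], pos_label=label)
    curves[label] = (p, r, t)

for label, (p, r, t) in curves.items():
    print(f"class {label}: {len(t)} thresholds")
```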

It's probably also worth noting that the new plotting convenience function plot_precision_recall_curve doesn't work with this: it takes the fitted model as a parameter and raises an error if the model is not a binary classifier.
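With a custom framework there is no sklearn model to hand to a plotting helper anyway (and plot_precision_recall_curve was deprecated in scikit-learn 1.0 and removed in 1.2 in favor of PrecisionRecallDisplay). A minimal sketch that plots one-vs-rest curves with plain matplotlib instead, assuming you can export an (n_samples, n_classes) score matrix whose column j scores classes[j]; the arrays below are made-up example data. label_binarize does the binarization asked about in the question: column j of Y is 1 where y_true == classes[j] and 0 elsewhere, which is the same split pos_label=classes[j] makes. Note that one y_scores value per sample, as in the question, is not enough for this; each class needs its own score for every sample.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.preprocessing import label_binarize

# Made-up example data (assumption): scores[:, j] is the score for classes[j].
classes = [1, 2, 3]
y_true = np.array([1, 3, 1, 2, 3, 2, 1, 3])
scores = np.array([
    [0.69, 0.20, 0.11],
    [0.05, 0.07, 0.88],
    [0.40, 0.32, 0.28],
    [0.30, 0.45, 0.25],
    [0.10, 0.30, 0.60],
    [0.50, 0.35, 0.15],
    [0.70, 0.10, 0.20],
    [0.20, 0.25, 0.55],
])

# One 0/1 column per class: Y[:, j] == 1 where y_true == classes[j].
Y = label_binarize(y_true, classes=classes)

fig, ax = plt.subplots()
for j, label in enumerate(classes):
    p, r, _ = precision_recall_curve(Y[:, j], scores[:, j])
    ap = average_precision_score(Y[:, j], scores[:, j])
    ax.step(r, p, where="post", label=f"class {label} (AP={ap:.2f})")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
fig.savefig("pr_curves.png")  # or plt.show() in an interactive session
```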

Upvotes: 1
