martian_rover

Reputation: 341

"ValueError: multilabel-indicator format is not supported" for roc_curve() sklearn

I am trying to get the tpr (true positive rate) and fpr (false positive rate) from roc_curve(), then the AUC score, and then plot the curve to see how my model behaves on multi-label (500 labels) imbalanced data, but I am getting the error below.

I am computing the predicted probability for each label so that I can adjust the decision threshold to get better precision, recall and accuracy, and to recover as many of the target labels as possible when predicting.

Code:

from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import ClassifierChain

rfc = RandomForestClassifier(n_jobs=-1, random_state=0, class_weight='balanced')
clf2 = ClassifierChain(rfc)
clf2.fit(X_train, y_train)
# predict_proba returns one probability per label: shape (n_samples, n_classes)
y_pred = clf2.predict_proba(X_test)

y_pred.shape
>> (8125,500)

y_pred[0]
>> array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  , 0.01, 0.  , 0.  , 0.  ,
        0.  , 0.01, 0.  , 0.  , 0.  , 0.01, 0.  , 0.01, 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.03, 0.  , 0.  , 0.  , 0.  , 0.  , 0.01,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.5 , 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.  ,
        0.  , 0.05, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.02, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.03, 0.04, 0.  ,
        0.  , 0.  , 0.01, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.02, 0.  ,
        0.  , 0.01, 0.  , 0.01, 0.  , 0.28, 0.  , 0.  , 0.  , 0.  , 0.01,
        0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ,
        0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.02, 0.07, 0.  , 0.01, 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.02, 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.02, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.02, 0.01, 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  ,
        0.  , 0.  , 0.  , 0.03, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.15, 0.  , 0.  , 0.02, 0.  ,
        0.01, 0.  , 0.11, 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.02, 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.02, 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.1 , 0.02, 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.02,
        0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.01, 0.  , 0.  , 0.01, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  ]])

from sklearn.metrics import roc_auc_score,roc_curve,precision_recall_curve
fpr, tpr, thresholds = roc_curve(y_test,y_pred)

The last line of code gives me the error.

Traceback:

ValueError                                Traceback (most recent call last)

<ipython-input-72-ea45ece64953> in <module>()
      1 from sklearn.metrics import roc_auc_score,roc_curve,precision_recall_curve
----> 2 fpr, tpr, thresholds = roc_curve(y_test,y_pred)

1 frames

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    534     if not (y_type == "binary" or
    535             (y_type == "multiclass" and pos_label is not None)):
--> 536         raise ValueError("{0} format is not supported".format(y_type))
    537 
    538     check_consistent_length(y_true, y_score, sample_weight)

ValueError: multilabel-indicator format is not supported

Upvotes: 1

Views: 7032

Answers (1)

amiola

Reputation: 3026

The point here is that, as stated in the docs for sklearn.metrics.roc_curve(),

Note: this implementation is restricted to the binary classification task.

while your target data (y_train and y_test) is multilabel (sklearn.utils.multiclass.type_of_target(y_train) is 'multilabel-indicator').

That said, there are different ways to evaluate a multilabel (or multioutput) classifier; one approach consists in measuring a metric for each individual label and then averaging the results across all labels (so-called macro-averaging, which is not the only option; see here for more references).
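
For the AUC, scikit-learn provides this macro-averaging out of the box via roc_auc_score. A minimal sketch, assuming y_test is the multilabel indicator matrix and y_pred the (n_samples, n_classes) probability matrix from the question (note it will fail for any label with no positive sample in y_test):

from sklearn.metrics import roc_auc_score

# one AUC per label, averaged with equal weight across the 500 labels
macro_auc = roc_auc_score(y_test, y_pred, average='macro')

# average=None returns the individual per-label AUCs instead
per_label_auc = roc_auc_score(y_test, y_pred, average=None)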

In the case of a ROC curve, this means drawing one ROC curve per label/class, either by first training n_classes binary classifiers (OvA strategy) or, as in your case, by exploiting an inherently multilabel classifier. Then, as shown here, you can also compute and draw a macro-average ROC curve. Depending on the averaging method used, you end up with different ways of extending this binary metric to the multilabel setting.
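
A rough sketch along the lines of the linked example, assuming y_test has been converted to a NumPy indicator array and every label has at least one positive sample in the test set (otherwise roc_curve is undefined for that label):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

n_classes = y_test.shape[1]
fpr, tpr, roc_auc = {}, {}, {}

# one binary ROC curve per label: column i of y_test vs column i of y_pred
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# macro-average: interpolate every per-label curve onto a common FPR grid
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes

plt.plot(all_fpr, mean_tpr, label='macro-average ROC (AUC = %0.2f)' % auc(all_fpr, mean_tpr))
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.legend()
plt.show()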

Upvotes: 1
