sereizam

Reputation: 2130

Can't figure out class ordering in precision_recall_fscore_support return values in scikit-learn

After numerous attempts, I still can't work out how to recover the per-class metrics from the values returned by precision_recall_fscore_support.

For example, given this typical training setup:

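from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split

# A sorted list rather than a set: sets have no .index() method
target_names = sorted(set(y))
y = [target_names.index(x) for x in y]
X_train, X_test, y_train, y_test = train_test_split(X, y)

# Some classification ...

y_pred = clf.predict(X_test)
precision, recall, f1, support = precision_recall_fscore_support(y_test, y_pred)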

Here, len(set(y_test)) == len(support), so I assume that every class present in y_test appears in the return values. But I can't find out how they are ordered, so I can't tell which metrics correspond to which class.

Thanks for your help!

Upvotes: 0

Views: 809

Answers (1)

Vivek Kumar

Reputation: 36617

The labels are in sorted order. Quoting the documentation:

By default, all labels in y_true and y_pred are used in sorted order

The order of classes is determined by the labels parameter of precision_recall_fscore_support. If it is not supplied, the default behaviour is to collect all classes present in y_true and y_pred and arrange them in sorted order.
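To fix the order yourself instead of relying on the default, you can pass labels explicitly; the returned arrays then follow exactly that order. A minimal sketch reusing the toy data from the documentation example below (the order ['pig', 'cat', 'dog'] is an arbitrary choice, only for illustration):

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

y_true = np.array(['cat', 'pig', 'dog', 'cat', 'dog', 'pig'])
y_pred = np.array(['cat', 'dog', 'pig', 'cat', 'cat', 'dog'])

# Arbitrary, unsorted order chosen for illustration;
# every returned array follows this order instead of the sorted one.
order = ['pig', 'cat', 'dog']
precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred, labels=order
)

# Position i in each array belongs to order[i]
for label, p, r, f, s in zip(order, precision, recall, f1, support):
    print(f"{label}: precision={p:.2f} recall={r:.2f} f1={f:.2f} support={s}")
```

Passing labels is also the safest way to keep a stable layout across different test splits.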

Example from the documentation:

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

y_true = np.array(['cat', 'pig', 'dog', 'cat', 'dog', 'pig'])
y_pred = np.array(['cat', 'dog', 'pig', 'cat', 'cat', 'dog'])

precision_recall_fscore_support(y_true, y_pred)

Output:

(array([ 0.66666667,  0.        ,  0.        ]),
 array([ 1.,  0.,  0.]),
 array([ 0.8,  0. ,  0. ]),
 array([2, 2, 2]))

The tuple above has four arrays (precision, recall, F-score and support), and each array has three elements, one each for 'cat', 'dog' and 'pig'. (You can check by hand that the metrics are arranged according to the sorted classes 'cat', 'dog', 'pig'.)
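If you want to attach class names to the arrays programmatically, one option is sklearn.utils.multiclass.unique_labels, which returns the sorted union of the labels in y_true and y_pred and should therefore match the default order. A minimal sketch on the same toy data (the dictionary layout is just one possible choice):

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
from sklearn.utils.multiclass import unique_labels

y_true = np.array(['cat', 'pig', 'dog', 'cat', 'dog', 'pig'])
y_pred = np.array(['cat', 'dog', 'pig', 'cat', 'cat', 'dog'])

# Sorted union of the labels found in y_true and y_pred
labels = unique_labels(y_true, y_pred)
precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred)

# Position i in every returned array corresponds to labels[i]
per_class = {
    str(label): {
        "precision": float(p),
        "recall": float(r),
        "f1": float(f),
        "support": int(s),
    }
    for label, p, r, f, s in zip(labels, precision, recall, f1, support)
}
print(per_class["cat"])
```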

Even if you change the order in which the labels appear:

y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig'])
y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog'])

the output is the same:

(array([ 0.66666667,  0.        ,  0.        ]),
 array([ 1.,  0.,  0.]),
 array([ 0.8,  0. ,  0. ]),
 array([2, 2, 2]))
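Both variants contain the same (true, predicted) pairs in a different order, and the first-appearance order of the classes differs ('cat', 'pig', 'dog' versus 'cat', 'dog', 'pig'). A quick sketch that checks the two results are identical:

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Variant A: the documentation example
y_true_a = np.array(['cat', 'pig', 'dog', 'cat', 'dog', 'pig'])
y_pred_a = np.array(['cat', 'dog', 'pig', 'cat', 'cat', 'dog'])
# Variant B: same pairs, different order of appearance
y_true_b = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig'])
y_pred_b = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog'])

result_a = precision_recall_fscore_support(y_true_a, y_pred_a)
result_b = precision_recall_fscore_support(y_true_b, y_pred_b)

# Each result is (precision, recall, f1, support); compare array by array
same = all(np.allclose(a, b) for a, b in zip(result_a, result_b))
print(same)  # True
```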

The same happens if y has numerical values: the classes are then arranged in ascending numerical order.
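Applied to the setup in the question: if target_names is a sorted list and y holds indices into it, then position i in the returned arrays is target_names[i], provided every class shows up in y_test or y_pred. If a class is missing from both in some split, the default output simply has fewer entries and the positions no longer line up, so passing labels explicitly is safer. A minimal sketch with made-up data (y_raw and y_pred are hypothetical; zero_division is a parameter of recent scikit-learn versions, used here only to silence the warning for a class that is never predicted):

```python
from sklearn.metrics import precision_recall_fscore_support

# Hypothetical raw labels standing in for the question's y
y_raw = ['spam', 'ham', 'eggs', 'ham', 'spam', 'eggs']
target_names = sorted(set(y_raw))            # ['eggs', 'ham', 'spam']
y = [target_names.index(x) for x in y_raw]   # integer class indices

# Hypothetical predictions; class 0 ('eggs') is never predicted
y_pred = [2, 1, 1, 1, 2, 1]

# One slot per class, in index order, even if a class were absent
# from both y and y_pred in a given split
labels = list(range(len(target_names)))
precision, recall, f1, support = precision_recall_fscore_support(
    y, y_pred, labels=labels, zero_division=0
)

# Position i corresponds to target_names[i]
for i, name in enumerate(target_names):
    print(name, precision[i], recall[i], support[i])
```

If you only need a readable summary, classification_report accepts the same labels together with target_names.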

Hope this clears things up. Feel free to ask if anything is still unclear.

Upvotes: 2
