pcpetepete

Reputation: 65

sklearn classification_report ValueError: Unknown label type:

I am trying to produce a simple classification report from the output of a Keras model's predictions. Both inputs are 1D arrays, but the error is still thrown.

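    import numpy as np
    from sklearn.metrics import classification_report

    # model, test_data and test_labels come from the Keras setup
    Y_pred = np.squeeze(model.predict(test_data[0:5]))
    classification_report(test_labels[0:5], Y_pred)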

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    <ipython-input-235-49afd2f46d17> in <module>()
    ----> 1 classification_report(test_labels[0:5], Y_pred)

    /Library/Python/2.7/site-packages/sklearn/metrics/classification.pyc in classification_report(y_true, y_pred, labels, target_names, sample_weight, digits)
       1356 
       1357     if labels is None:
    -> 1358         labels = unique_labels(y_true, y_pred)
       1359     else:
       1360         labels = np.asarray(labels)

    /Library/Python/2.7/site-packages/sklearn/utils/multiclass.pyc in unique_labels(*ys)
         97     _unique_labels = _FN_UNIQUE_LABELS.get(label_type, None)
         98     if not _unique_labels:
    ---> 99         raise ValueError("Unknown label type: %s" % repr(ys))
        100 
        101     ys_labels = set(chain.from_iterable(_unique_labels(y) for y in ys))

    ValueError: Unknown label type: (array([-0.38947693,  0.18258421, -0.00295772, -0.06293461, -0.29382696]), array([-0.46586546,  0.1359883 , -0.00223112, -0.08303966, -0.29208803]))

Both inputs have the same type, so I am confused about why this does not work. I have tried explicitly casting them to dtype=float and flattening them, but the error persists.

Upvotes: 1

Views: 3856

Answers (1)

Abhishek Thakur

Reputation: 17015

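classification_report works only for classification problems, where y_true and y_pred contain discrete class labels. Both arrays in your error message hold non-integer floats (for example -0.38947693), so sklearn infers the target type as "continuous", and unique_labels has no handler for that type. Casting to dtype=float or flattening does not help, because the values themselves are still continuous.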
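You can check what sklearn thinks your targets are with sklearn.utils.multiclass.type_of_target, the helper unique_labels relies on to pick a handler. A minimal sketch, using the two arrays copied from the error message (behaviour shown is for reasonably recent sklearn versions):

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# The two arrays reported in the ValueError above.
y_pred = np.array([-0.38947693, 0.18258421, -0.00295772, -0.06293461, -0.29382696])
y_true = np.array([-0.46586546, 0.1359883, -0.00223112, -0.08303966, -0.29208803])

# Non-integer floats are treated as regression targets ("continuous"),
# which classification_report cannot handle.
print(type_of_target(y_pred))  # continuous
print(type_of_target(y_true))  # continuous

# Casting to float or flattening does not change the inferred type,
# because the values are still non-integer.
print(type_of_target(y_pred.astype(float).ravel()))  # continuous

# Discrete 0/1 labels, by contrast, are recognized as "binary".
print(type_of_target(np.array([0, 1, 1, 0, 1])))  # binary
```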

If you do have a classification problem (e.g. binary), threshold the predicted scores to turn them into class labels first:

    Y_pred = np.squeeze(model.predict(test_data[0:5]))
    threshold = 0.5
    classification_report(test_labels[0:5], Y_pred > threshold)

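The comparison Y_pred > threshold maps every prediction greater than 0.5 (in the example above) to 1 (True) and everything else to 0 (False), so classification_report receives discrete labels instead of continuous scores.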
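A self-contained sketch of the thresholding step, assuming the model ends in a sigmoid (so its outputs are probabilities in [0, 1]) and the true labels are 0/1. The arrays and the helper scores_to_labels are hypothetical, for illustration only; the .astype(int) cast is optional but yields integer labels instead of booleans.

```python
import numpy as np
from sklearn.metrics import classification_report

def scores_to_labels(scores, threshold=0.5):
    """Hypothetical helper: map continuous scores (e.g. sigmoid outputs) to 0/1 labels."""
    # ravel() plays the role of np.squeeze for an (n, 1) Keras output
    return (np.asarray(scores).ravel() > threshold).astype(int)

# Hypothetical ground truth and model scores, for illustration only.
y_true = np.array([0, 1, 1, 0, 1])
scores = np.array([0.12, 0.81, 0.47, 0.30, 0.95])

y_pred = scores_to_labels(scores)
print(y_pred)  # [0 1 0 0 1]
print(classification_report(y_true, y_pred))
```

If your labels are genuinely continuous (as the arrays in the error message suggest), no threshold makes sense, and classification metrics are the wrong tool for the problem.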

Upvotes: 3
