Reputation: 31
Using scikit-learn to classify a binary problem. I'm getting a perfect classification_report (all 1's), yet prediction on held-out data gives 0.36. How can that be?
I'm familiar with imbalanced labels, but I don't think that's the case here, since f1 and the other score columns, as well as the confusion matrix, indicate a perfect score.
import pandas as pd
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, classification_report

# Set aside the last 19 rows for prediction.
X1, X_Pred, y1, y_Pred = train_test_split(X, y, test_size=19,
                                          shuffle=False, random_state=None)
X_train, X_test, y_train, y_test = train_test_split(X1, y1, test_size=0.4,
                                                    stratify=y1, random_state=11)

clcv = DecisionTreeClassifier()
scorecv = cross_val_score(clcv, X1, y1, cv=StratifiedKFold(n_splits=4),
                          scoring='f1')  # f1 to balance precision/recall

clcv.fit(X1, y1)
y_predict = clcv.predict(X1)

cm = confusion_matrix(y1, y_predict)
cm_df = pd.DataFrame(cm, index=['0', '1'], columns=['0', '1'])
print(cm_df)
print(classification_report(y1, y_predict))
print('Prediction score:', clcv.score(X_Pred, y_Pred))  # unseen data
Output:
confusion:
      0   1
0  3011   0
1     0  44

              precision    recall  f1-score   support

       False       1.00      1.00      1.00      3011
        True       1.00      1.00      1.00        44

   micro avg       1.00      1.00      1.00      3055
   macro avg       1.00      1.00      1.00      3055
weighted avg       1.00      1.00      1.00      3055

Prediction score: 0.36
Upvotes: 1
Views: 1018
Reputation: 22023
The issue is that you are overfitting.
There is a lot of code here that is never used, so let's prune:
# Set aside the last 19 rows for prediction.
X1, X_Pred, y1, y_Pred = train_test_split(X, y, test_size=19,
                                          shuffle=False, random_state=None)

clcv = DecisionTreeClassifier()
clcv.fit(X1, y1)
y_predict = clcv.predict(X1)  # predicting on the same data used for fitting

cm = confusion_matrix(y1, y_predict)
cm_df = pd.DataFrame(cm, index=['0', '1'], columns=['0', '1'])
print(cm_df)
print(classification_report(y1, y_predict))
print('Prediction score:', clcv.score(X_Pred, y_Pred))  # unseen data
So clearly, there is no cross validation here: the confusion matrix and classification report are computed on the very data the tree was fit on, and an unconstrained decision tree can memorize its training set perfectly. The obvious reason for the low prediction score on the 19 unseen rows is that the classifier is overfitting.
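To see the gap concretely, here is a minimal sketch on synthetic data (make_classification is just a stand-in for your dataset, which we don't have; the numbers will differ but the pattern won't):

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic imbalanced binary problem, loosely mimicking the 3011-vs-44 split.
X, y = make_classification(n_samples=3000, n_features=20,
                           weights=[0.95, 0.05], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y,
                                                    random_state=0)

tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
print('Score on training data:', tree.score(X_train, y_train))  # exactly 1.0: memorized
print('Score on unseen data:  ', tree.score(X_test, y_test))    # below 1.0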
Use the scores from the cross validation instead, and you should see the issue directly.
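For instance, printing the per-fold scores that your own cross_val_score call already computes but never uses (reusing X1 and y1 from above):

from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

clcv = DecisionTreeClassifier()
scorecv = cross_val_score(clcv, X1, y1,
                          cv=StratifiedKFold(n_splits=4), scoring='f1')
print('f1 per fold:', scorecv)     # each fold is scored on held-out data
print('mean f1:', scorecv.mean())  # a realistic estimate, unlike the all-1.00 report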
Upvotes: 2