Reputation: 631
I'm trying to figure out how to do k-fold cross validation. I was hoping someone could explain the difference between my two print statements. They give me noticeably different results, and I thought they would be the same.
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import KFold, cross_val_score

# train is my training data
# target is my target, my binary class
dtc = DecisionTreeClassifier()
kf = KFold(n_splits=10)
print(cross_val_score(dtc, train, target, cv=kf, scoring='accuracy'))
print(cross_val_score(dtc, train, target, cv=10, scoring='accuracy'))
Upvotes: 2
Views: 3142
Reputation: 29680
DecisionTreeClassifier derives from ClassifierMixin, and so, as mentioned in the docs under "Computing cross-validated metrics" (emphasis mine):

When the cv argument is an integer, cross_val_score uses the KFold or StratifiedKFold strategies by default, the latter being used if the estimator derives from ClassifierMixin.
So here, when you pass cv=10 you are using the StratifiedKFold strategy, whereas when you pass cv=kf you are using the regular KFold strategy.
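If you want the two calls to agree, you can pass a StratifiedKFold explicitly, since that is what the integer cv builds internally for a classifier. A minimal sketch; the make_classification data here is a hypothetical stand-in for your train/target:

from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Stand-in data; in the question, train and target come from elsewhere
train, target = make_classification(n_samples=100, random_state=0)

dtc = DecisionTreeClassifier(random_state=0)

# With an integer cv and a classifier, cross_val_score constructs a
# StratifiedKFold internally, so these two calls use identical splits
skf = StratifiedKFold(n_splits=10)
print(cross_val_score(dtc, train, target, cv=skf, scoring='accuracy'))
print(cross_val_score(dtc, train, target, cv=10, scoring='accuracy'))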
In classification, stratification ensures that each test fold preserves approximately the same class proportions as the full dataset. See Understanding stratified cross-validation on Cross-Validated for more background.
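To see the difference concretely, you can compare the class balance of each test fold under the two strategies. A small sketch with synthetic, imbalanced data (illustrative only, not the asker's dataset):

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Illustrative imbalanced target: 90 zeros followed by 10 ones
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((100, 1))  # feature values are irrelevant to the splitting

for name, cv in [("KFold", KFold(n_splits=5)),
                 ("StratifiedKFold", StratifiedKFold(n_splits=5))]:
    # Fraction of positive samples in each test fold
    fractions = [y[test].mean() for _, test in cv.split(X, y)]
    print(name, [f"{f:.2f}" for f in fractions])

KFold splits the data in order, so the first folds here contain no positives at all, while StratifiedKFold keeps roughly 10% positives in every test fold. That difference in fold composition is why your two print statements disagree.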
Upvotes: 1