salty_coffee

Reputation: 631

In cross_val_score, how is the parameter cv being used differently?

I'm trying to figure out how to do k-fold cross-validation. I was hoping someone could tell me the difference between my two print statements. They give me substantially different results, and I thought they would be the same.

from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# train is my training data,
# target is my target, my binary class.

dtc = DecisionTreeClassifier()
kf = KFold(n_splits=10)
print(cross_val_score(dtc, train, target, cv=kf, scoring='accuracy'))
print(cross_val_score(dtc, train, target, cv=10, scoring='accuracy'))

Upvotes: 2

Views: 3142

Answers (1)

miradulo

Reputation: 29680

DecisionTreeClassifier derives from ClassifierMixin, and so, as mentioned in the scikit-learn docs:

Computing cross-validated metrics

When the cv argument is an integer, cross_val_score uses the KFold or StratifiedKFold strategies by default, the latter being used if the estimator derives from ClassifierMixin.

So here, when you pass cv=10 you are using the StratifiedKFold strategy, whereas when you pass cv=kf you are using the plain KFold strategy you constructed yourself.
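One way to check this, as a sketch reusing the train and target from your question: pass a StratifiedKFold explicitly and the scores should reproduce the cv=10 run. Fixing random_state on the tree is my addition here, so that both runs build identical trees on identical splits.

from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# With the splitter and the tree both deterministic, the two score
# arrays below should match exactly.
dtc = DecisionTreeClassifier(random_state=0)
skf = StratifiedKFold(n_splits=10)
print(cross_val_score(dtc, train, target, cv=skf, scoring='accuracy'))
print(cross_val_score(dtc, train, target, cv=10, scoring='accuracy'))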

In classification, stratification generally attempts to ensure that each test fold has approximately the same class proportions as the full dataset. See Understanding stratified cross-validation on Cross-Validated for some more info.
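To see why this can make the two score arrays diverge, here is a small self-contained illustration on hypothetical toy data (not your dataset): with a sorted, imbalanced target, plain KFold can produce test folds whose class balance is nothing like the dataset's, while StratifiedKFold keeps the proportions roughly constant in every fold.

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.zeros((100, 1))              # features are irrelevant for this demo
y = np.array([0] * 80 + [1] * 20)   # 80/20 imbalance, sorted by class

for name, splitter in [("KFold", KFold(n_splits=5)),
                       ("StratifiedKFold", StratifiedKFold(n_splits=5))]:
    # fraction of the positive class in each test fold
    fractions = [y[test].mean() for _, test in splitter.split(X, y)]
    print(name, [round(f, 2) for f in fractions])

This prints [0.0, 0.0, 0.0, 0.0, 1.0] for KFold (the folds just follow the sort order, so the last fold is all positives) and [0.2, 0.2, 0.2, 0.2, 0.2] for StratifiedKFold (every fold mirrors the 80/20 split).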

Upvotes: 1
