Mitar

Reputation: 6980

How to run only one fold of cross validation in sklearn?

I have the following code to run a 10-fold cross-validation in sklearn:

cv = model_selection.KFold(n_splits=10, shuffle=True, random_state=0)
scores = model_selection.cross_val_score(MyEstimator(), x_data, y_data, cv=cv, scoring='neg_mean_squared_error') * -1

For debugging purposes, while I am trying to make MyEstimator work, I would like to run only one fold of this cross-validation, instead of all 10. Is there an easy way to keep this code but just say to run the first fold and then exit?

I would still like the data to be split into 10 parts, but with only one combination of those 10 parts fitted and scored, instead of all 10 combinations.

Upvotes: 1

Views: 1275

Answers (1)

Vivek Kumar

Reputation: 36599

Not directly with cross_val_score, as far as I know. You could set n_splits to its minimum value of 2, but that gives a 50:50 train/test split, which you may not want.

If you want to maintain a 90:10 ratio and test other parts of the code, such as MyEstimator(), you can use a workaround.

You can use KFold.split() to get the first set of train and test indices, and then break out of the loop after the first iteration.

cv = model_selection.KFold(n_splits=10, shuffle=True, random_state=0)
for train_index, test_index in cv.split(x_data):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = x_data[train_index], x_data[test_index]
    y_train, y_test = y_data[train_index], y_data[test_index]
    break
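
Since KFold.split() yields its index pairs lazily, the same first split can probably be taken with next() instead of a loop and break. A minimal sketch, assuming NumPy arrays; the synthetic x_data and y_data here are made up purely for illustration:

```python
import numpy as np
from sklearn import model_selection

# Hypothetical stand-in data: 100 samples, 3 features.
rng = np.random.RandomState(0)
x_data = rng.rand(100, 3)
y_data = rng.rand(100)

cv = model_selection.KFold(n_splits=10, shuffle=True, random_state=0)

# split() yields (train_index, test_index) pairs one at a time;
# next() takes only the first pair and computes nothing further.
train_index, test_index = next(cv.split(x_data))

X_train, X_test = x_data[train_index], x_data[test_index]
y_train, y_test = y_data[train_index], y_data[test_index]

print(X_train.shape, X_test.shape)
```

With 100 samples and n_splits=10 this keeps the 90:10 proportion, and because random_state is fixed the first fold is the same on every run.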

Now use X_train and y_train to train the estimator, and X_test and y_test to score it.

Instead of:

scores = model_selection.cross_val_score(MyEstimator(),
                                         x_data, y_data,
                                         cv=cv,
                                         scoring='neg_mean_squared_error')

Your code becomes:

from sklearn.metrics import mean_squared_error

# Fit on the training part of the first fold and predict on its held-out part.
myEstimator_fitted = MyEstimator().fit(X_train, y_train)
y_pred = myEstimator_fitted.predict(X_test)

# Collect the score in a list to mirror the per-fold scores that cross_val_score returns.
scores = []
scores.append(mean_squared_error(y_test, y_pred))

Rest assured, cross_val_score does essentially the same thing internally for each fold, with some additions such as parallel processing.
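
One way to check this, which also suggests a possible shortcut: the cv argument of cross_val_score is documented to accept an iterable of (train, test) index pairs as well as a splitter object, so a one-element list holding only the first split should run a single fold. The sketch below is only an illustration under assumptions: LinearRegression stands in for MyEstimator, the data is synthetic, and the scorer name 'neg_mean_squared_error' is the one used by recent scikit-learn releases.

```python
import numpy as np
from sklearn import model_selection
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# Hypothetical stand-in data; LinearRegression plays the role of MyEstimator.
rng = np.random.RandomState(0)
x_data = rng.rand(100, 3)
y_data = x_data @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.randn(100)

cv = model_selection.KFold(n_splits=10, shuffle=True, random_state=0)
first_split = next(cv.split(x_data))  # (train_index, test_index) of the first fold

# Manual version: fit on the training part, score on the held-out part.
train_index, test_index = first_split
model = LinearRegression().fit(x_data[train_index], y_data[train_index])
manual_mse = mean_squared_error(y_data[test_index],
                                model.predict(x_data[test_index]))

# Same fold through cross_val_score: cv is a list holding just that one split.
scores = model_selection.cross_val_score(
    LinearRegression(), x_data, y_data,
    cv=[first_split], scoring='neg_mean_squared_error') * -1

print(manual_mse, scores)
```

If the two printed values agree, the manual fit-and-score really is what cross_val_score does for that fold, and the original call can be kept for debugging by swapping cv=cv for cv=[first_split].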

Upvotes: 2
