Chris Gorgolewski

Reputation: 621

Confusing example of nested cross validation in scikit-learn

I'm looking at this example from scikit-learn documentation: http://scikit-learn.org/0.18/auto_examples/model_selection/plot_nested_cross_validation_iris.html

It seems to me that cross-validation is not performed in an unbiased way here. Both GridSearchCV (supposedly the inner CV loop) and cross_val_score (supposedly the outer CV loop) are using the same data and the same folds. Therefore, there is an overlap between the data the classifier was trained on and the data it was evaluated on. What am I getting wrong?

Upvotes: 4

Views: 2168

Answers (4)

Arpit Omprakash

Reputation: 118

When I first read the documentation and most of the answers, I got very confused. But now I think I get it, so I will try to explain my understanding. I hope it helps :)

inner_cv = KFold(n_splits=4, shuffle=True, random_state=i)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=i)

This part is straightforward. In the lines above, we create two cross-validators called inner_cv and outer_cv. They are configured identically; what makes the procedure nested is which data each of them is asked to split.
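To make "which data each one splits" concrete, here is a minimal sketch that only records fold sizes. It assumes the iris data from load_iris (150 rows) and uses random_state=0 in place of the example's loop variable i; outer_sizes and inner_sizes are hypothetical names. outer_cv splits all 150 rows, while inner_cv is only ever applied to the rows of one outer training fold.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold

X_iris, y_iris = load_iris(return_X_y=True)  # 150 samples

# Same settings as the example; i is the trial index (fixed to 0 here).
i = 0
inner_cv = KFold(n_splits=4, shuffle=True, random_state=i)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=i)

outer_sizes = []  # (n_train, n_test) for each outer fold
inner_sizes = []  # inner (n_train, n_validation) pairs, per outer fold
for train_idx, test_idx in outer_cv.split(X_iris):
    outer_sizes.append((len(train_idx), len(test_idx)))
    # inner_cv only sees the rows of this outer training fold
    X_tr = X_iris[train_idx]
    inner_sizes.append([(len(a), len(b)) for a, b in inner_cv.split(X_tr)])

print(outer_sizes)
print(inner_sizes[0])
```

Reusing the same random_state for both splitters does not make the folds coincide, because the two splitters are applied to arrays of different sizes.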

# Non_nested parameter search and scoring
clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv)
clf.fit(X_iris, y_iris)
non_nested_scores[i] = clf.best_score_

This part is straightforward again. We ignore the outer_cv cross-validator and use only inner_cv. The first line only creates the GridSearchCV object; nothing is computed yet. The second line, clf.fit(X_iris, y_iris), does the actual work: for every parameter combination in p_grid, it runs cross-validation on the whole dataset using the inner_cv splits, picks the combination with the best mean validation score, and (because refit=True is the default) refits an SVM with those hyperparameters on the whole dataset. In the last line, best_score_ is the mean cross-validated score of that best combination on the same inner_cv folds that were used to choose it.
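A self-contained sketch of this non-nested step, assuming load_iris for the data, random_state=0 for the trial index, and a small C/gamma grid similar to the example's (the exact grid values are an assumption here). It checks that no search has happened before fit(), and that best_score_ is simply the mean inner-CV validation score of the winning candidate rather than a score of the refitted model on new data.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.svm import SVC

X_iris, y_iris = load_iris(return_X_y=True)
svm = SVC(kernel="rbf")
# Illustrative grid (assumed to resemble the one in the linked example).
p_grid = {"C": [1, 10, 100], "gamma": [0.01, 0.1]}
inner_cv = KFold(n_splits=4, shuffle=True, random_state=0)

clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv)
# Constructing the object does not search anything yet.
assert not hasattr(clf, "best_score_")

# Search over p_grid with inner_cv, then refit the best candidate on all rows.
clf.fit(X_iris, y_iris)

# best_score_ is read straight out of the inner-CV results table.
best = clf.best_index_
print(clf.best_params_, clf.best_score_)
print(clf.cv_results_["mean_test_score"][best])
```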

This is non-nested cross-validation. Because the hyperparameters are selected and the performance is reported on the same data splits, there is data leakage: the hyperparameters were chosen precisely because they scored well on those validation folds, so the reported score is optimistically biased. So, what the webpage suggests is to use the following:

# Nested CV with parameter optimization
clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv)
nested_score = cross_val_score(clf, X=X_iris, y=y_iris, cv=outer_cv)
nested_scores[i] = nested_score.mean()

In the first line here, we instantiate the GridSearchCV object with the inner_cv cross-validator (but do not fit it). The second line does a lot of things. First, cross_val_score uses outer_cv to break the data into four splits, each with a training part and a test part; let's call them x_tr_0/x_ts_0, x_tr_1/x_ts_1, x_tr_2/x_ts_2 and x_tr_3/x_ts_3. In each outer fold, a fresh copy of the GridSearchCV object is fitted on the training part only. So the inner_cv cross-validator works on the training data of the outer split: inside GridSearchCV, x_tr_0 is broken down further into the inner splits x_tr_0_0, x_tr_0_1, x_tr_0_2 and x_tr_0_3, and the hyperparameter tuning is done on these "inner" splits.

Once the best hyperparameters are found, GridSearchCV refits the model with them on the whole outer training part (for example x_tr_0), and cross_val_score scores that model on the corresponding outer test part (x_ts_0). The outer test data is never seen by the "inner" hyperparameter search, so the resulting scores are not inflated by the tuning and there is no data leakage between tuning and evaluation.
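The single cross_val_score call can be unrolled into an explicit loop. This is a sketch under the same illustrative assumptions (iris data, random_state=0, an assumed C/gamma grid); manual_scores and fold_clf are hypothetical names. It should produce the same per-fold scores as cross_val_score, since cross_val_score likewise clones the estimator, fits it on each outer training fold and scores it on the matching test fold.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score
from sklearn.svm import SVC

X_iris, y_iris = load_iris(return_X_y=True)
p_grid = {"C": [1, 10, 100], "gamma": [0.01, 0.1]}  # assumed grid
inner_cv = KFold(n_splits=4, shuffle=True, random_state=0)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=0)
clf = GridSearchCV(estimator=SVC(kernel="rbf"), param_grid=p_grid, cv=inner_cv)

manual_scores = []
for train_idx, test_idx in outer_cv.split(X_iris, y_iris):
    # Outer split: x_tr_k / x_ts_k share no rows.
    assert np.intersect1d(train_idx, test_idx).size == 0
    X_tr, y_tr = X_iris[train_idx], y_iris[train_idx]
    X_ts, y_ts = X_iris[test_idx], y_iris[test_idx]
    fold_clf = clone(clf)  # fresh, unfitted copy for every outer fold
    # Inner loop: the grid search only sees X_tr and splits it with inner_cv.
    fold_clf.fit(X_tr, y_tr)
    # Outer loop: best params (refitted on X_tr) scored on the unseen X_ts.
    manual_scores.append(fold_clf.score(X_ts, y_ts))

manual_scores = np.array(manual_scores)
nested_score = cross_val_score(clf, X=X_iris, y=y_iris, cv=outer_cv)
print(manual_scores)
print(nested_score)
```

The assertion inside the loop is the crux of the question: the rows used to score a fold never reach that fold's grid search.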

--- Adapted from a previous answer of mine: https://stackoverflow.com/a/78544053/11426625

Upvotes: 0

Gael Varoquaux

Reputation: 2476

They are not using the same data. Granted, the code of the example does not make this apparent, because the splits are not visible: the first split is done inside cross_val_score, and the second split is done inside GridSearchCV (that's the whole point of the GridSearchCV object). Using functions and objects rather than hand-written for loops may make things less transparent, but it:

  1. Enables reuse
  2. Adds many "little things" that would make the for loop tedious, such as parallel computing, support for different scoring functions, etc.
  3. Is actually safer in terms of avoiding data leakage, because our splitting code has been audited many times.

If you are not convinced, take a look at the code of cross_val_score and GridSearchCV.
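One way to check this without reading the source is cross_validate (a sibling of cross_val_score) with return_estimator=True, which keeps the GridSearchCV fitted on each outer training fold. A sketch, assuming iris data, random_state=0 and an illustrative C/gamma grid:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, KFold, cross_validate
from sklearn.svm import SVC

X_iris, y_iris = load_iris(return_X_y=True)
p_grid = {"C": [1, 10, 100], "gamma": [0.01, 0.1]}  # assumed grid
inner_cv = KFold(n_splits=4, shuffle=True, random_state=0)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=0)
clf = GridSearchCV(estimator=SVC(kernel="rbf"), param_grid=p_grid, cv=inner_cv)

# Keep the GridSearchCV fitted on each outer training fold.
res = cross_validate(clf, X_iris, y_iris, cv=outer_cv, return_estimator=True)

# The object passed in is never fitted: every outer fold works on a clone.
print(hasattr(clf, "best_params_"))  # False

for est, score in zip(res["estimator"], res["test_score"]):
    # Each inner search ran on its own outer training fold, so the chosen
    # hyperparameters may differ from fold to fold.
    print(est.best_params_, round(est.best_score_, 3), round(score, 3))
```

Because the original clf stays unfitted, any best_score_ read from clf elsewhere in the example must come from a separate clf.fit call, not from the nested evaluation.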

The example was recently improved to make this explicit in the comments: http://scikit-learn.org/dev/auto_examples/model_selection/plot_nested_cross_validation_iris.html

(pull request: https://github.com/scikit-learn/scikit-learn/pull/7949)

Upvotes: 0

jorjasso

Reputation: 121

Totally agree: that nested-CV procedure is wrong. cross_val_score is taking the best hyperparameters computed by GridSearchCV and computing a CV score using those hyperparameters. In nested CV, you need an outer loop for assessing model performance and an inner loop for model selection, such that the portion of data used in the inner loop for model selection is never the same as the portion used for assessing model performance. For example, you could use a LOOCV outer loop for assessing performance (or 5-fold, 10-fold, or whatever you like) and 10-fold CV with grid search in the inner loop for model selection. That means that, if you have N observations, you perform model selection in the inner loop (using grid search and 10-fold CV, for example) on N-1 observations, and you assess the model performance on the left-out observation (or on the held-out sample if you choose another outer scheme). (Note that internally you are estimating N best models, in the sense of hyperparameters.) It would be helpful to have a link to the code of cross_val_score and GridSearchCV. Some references for nested CV are:

  • Christophe Ambroise and Geoffrey J. McLachlan. Selection bias in gene extraction on the basis of microarray gene-expression data. Proceedings of the National Academy of Sciences 99, 10 (2002), 6562-6566.
  • Gavin C. Cawley and Nicola L. C. Talbot. On overfitting in model selection and subsequent selection bias in performance evaluation. Journal of Machine Learning Research 11, Jul (2010), 2079-2107.
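The scheme described above (LOOCV outer loop, 10-fold grid search inner loop) can be written with scikit-learn objects by passing a GridSearchCV that uses 10-fold CV into cross_val_score with LeaveOneOut(). This is a sketch; the iris data and the deliberately small grid over C are assumptions chosen to keep the N x (2 x 10 + 1) SVM fits cheap, and loo_scores is a hypothetical name.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, KFold, LeaveOneOut, cross_val_score
from sklearn.svm import SVC

X_iris, y_iris = load_iris(return_X_y=True)  # N = 150
p_grid = {"C": [1, 10]}  # small illustrative grid

# Inner loop: model selection by grid search with 10-fold CV.
inner = GridSearchCV(
    SVC(kernel="rbf"),
    p_grid,
    cv=KFold(n_splits=10, shuffle=True, random_state=0),
)

# Outer loop: leave-one-out. Each of the N iterations runs the grid search
# on the N-1 remaining rows and scores the chosen model on the held-out row.
loo_scores = cross_val_score(inner, X_iris, y_iris, cv=LeaveOneOut())
print(len(loo_scores), loo_scores.mean())
```

Each entry of loo_scores is the accuracy on a single held-out observation, so it is either 0 or 1; the mean over the N entries is the LOOCV estimate.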

Note: I did not find anything in the documentation of cross_val_score indicating that the hyperparameters are optimized internally (using a parameter search, e.g. grid search + cross-validation) on the k-1 training folds, with those optimized parameters then used on the held-out fold. (What I am describing is different from the code in http://scikit-learn.org/dev/auto_examples/model_selection/plot_nested_cross_validation_iris.html.)

Upvotes: 1

Shree

Reputation: 73

@Gael - As I cannot add a comment, I am posting this in the answer section. I am not sure what Gael means by "the first split is done inside cross_val_score, and the second split is done inside GridSearchCV (that the whole point of the GridSearchCV object)". Are you implying that the cross_val_score function passes the (k-1)-fold data (used for training in the outer loop) to the clf object? That does not appear to be the case, as I can comment out the cross_val_score call, set nested_score[i] to a dummy value, and still obtain exactly the same clf.best_score_. This implies that GridSearchCV is evaluated separately and uses all available data, not a subset of the training data.

In nested CV, to the best of my understanding, the idea is that the inner loop does the hyperparameter search on a smaller subset of the training data, and the outer loop then uses these parameters to do a cross-validation. One of the reasons for using smaller training data in the inner loop is to avoid information leakage. That doesn't appear to be what is happening here. The inner loop first uses all the data to search for hyperparameters, which are then used for cross-validation in the outer loop. Thus, the inner loop has already seen all the data, and any testing done in the outer loop will suffer from information leakage. If I am mistaken, could you please point me to the section of code you are referring to in your answer?

Upvotes: 4
