Reputation: 3256
GridSearchCV implements a fit method that performs n-fold cross-validation to determine the best parameters. After this, we can apply the best estimator directly to the test data using predict(), as in this example: http://scikit-learn.org/stable/auto_examples/grid_search_digits.html
The example says: "The model is trained on the full development set."
However, we have only run n-fold cross-validation here. Is the classifier somehow also trained on the entire dataset? Or does predict() just use the best estimator, with the best parameters, from among the n folds?
Upvotes: 3
Views: 3009
Reputation: 1062
If you want to use predict, you'll need to set 'refit' to True. From the documentation:
refit : boolean
Refit the best estimator with the entire dataset.
If “False”, it is impossible to make predictions using
this GridSearchCV instance after fitting.
refit is True by default, so in the example predict uses the best estimator refit on the whole development (training) set, not a model trained on a single fold.
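To check this yourself, here is a minimal sketch. The dataset split, the small parameter grid, and the variable names are my own choices for illustration, not taken from the linked example. It compares the predictions of a default GridSearchCV with those of an SVC fitted by hand on the whole development set using the selected parameters, and also shows what is missing when refit=False.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# Split the digits data into a development set and a held-out test set.
X, y = load_digits(return_X_y=True)
X_dev, X_test, y_dev, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0
)

# Small illustrative grid (an assumption, not the grid from the linked example).
param_grid = {"C": [1, 10], "gamma": [1e-3, 1e-4]}

# refit is left at its default (True): after cross-validation, the best
# parameter combination is fitted once more on all of X_dev.
search = GridSearchCV(SVC(kernel="rbf"), param_grid, cv=3)
search.fit(X_dev, y_dev)
refit_predictions = search.predict(X_test)

# Fit an SVC by hand on the whole development set with the chosen parameters.
manual = SVC(kernel="rbf", **search.best_params_).fit(X_dev, y_dev)
same_predictions = bool((manual.predict(X_test) == refit_predictions).all())

# With refit=False the search still finds best_params_, but no final
# estimator is fitted, so there is nothing to predict with.
no_refit = GridSearchCV(SVC(kernel="rbf"), param_grid, cv=3, refit=False)
no_refit.fit(X_dev, y_dev)
has_best_estimator = hasattr(no_refit, "best_estimator_")

print("best parameters:", search.best_params_)
print("matches manual full-set fit:", same_predictions)
print("best_estimator_ present with refit=False:", has_best_estimator)
```

If the predictions match, that is consistent with predict delegating to an estimator trained on the full development set rather than to one of the per-fold models.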
Upvotes: 5