Reputation: 8297
I am trying to scale my features with StandardScaler and classify them with a KNeighborsClassifier, combining the scaler and the estimator in a pipeline.
Finally, I want to create a grid search cross-validator for that pipeline, where param_grid is a dictionary with n_neighbors as the hyperparameter and k_vals as its values.
def kNearest(k_vals):
    skf = StratifiedKFold(n_splits=5, random_state=23)
    svp = Pipeline([('ss', StandardScaler()),
                    ('knc', neighbors.KNeighborsClassifier())])
    parameters = {'n_neighbors': k_vals}
    clf = GridSearchCV(estimator=svp, param_grid=parameters, cv=skf)
    return clf
But doing this gives me the following error:
Invalid parameter n_neighbors for estimator Pipeline. Check the list of available parameters with `estimator.get_params().keys()`.
I've read the documentation, but still don't quite get what the error indicates and how to fix it.
Upvotes: 1
Views: 2553
Reputation: 40938
You are right, this is not exactly well-documented by scikit-learn. (Zero reference to it in the class docstring.)
If you use a pipeline as the estimator in a grid search, you need a special syntax for the keys of the parameter grid: the step name, followed by a double underscore, followed by the parameter name as you would pass it to that step's estimator. That is:
'<named_step>__<parameter>': value
In your case:
parameters = {'knc__n_neighbors': k_vals}
should do the trick.
Here knc is a named step in your pipeline. There is an attribute that shows these steps as a dictionary:
svp.named_steps
{'knc': KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=1, n_neighbors=5, p=2,
weights='uniform'),
'ss': StandardScaler(copy=True, with_mean=True, with_std=True)}
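As a small illustration of how step names and parameter names fit together, the sketch below rebuilds the same two-step pipeline and changes a nested parameter through set_params, which accepts the same double-underscore keys as param_grid. The value 7 is arbitrary and chosen only for the example.

```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Same step names as in the question: 'ss' and 'knc'.
svp = Pipeline([('ss', StandardScaler()),
                ('knc', KNeighborsClassifier())])

# Look up a step by the name it was given in the pipeline.
knn = svp.named_steps['knc']

# '<named_step>__<parameter>' routes the value to that step's estimator.
# (7 is an arbitrary illustrative value, not a recommendation.)
svp.set_params(knc__n_neighbors=7)

print(svp.named_steps['knc'].n_neighbors)
```

A grid search does essentially this for every candidate value, which is why a bare 'n_neighbors' key is rejected: the pipeline itself has no such parameter.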
And as your traceback alludes to:
svp.get_params().keys()
dict_keys(['memory', 'steps', 'ss', 'knc', 'ss__copy', 'ss__with_mean', 'ss__with_std', 'knc__algorithm', 'knc__leaf_size', 'knc__metric', 'knc__metric_params', 'knc__n_jobs', 'knc__n_neighbors', 'knc__p', 'knc__weights'])
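Putting it together, here is a minimal sketch of the function from the question with the corrected key. Several things are assumptions added for the example and are not in the original post: the iris dataset, the candidate values [1, 3, 5, 7], the early check against get_params(), and shuffle=True on StratifiedKFold (recent scikit-learn releases raise an error if random_state is set without it).

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


def kNearest(k_vals):
    # shuffle=True is an addition: random_state only has an effect
    # (and is only accepted by newer versions) when shuffling is on.
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=23)
    svp = Pipeline([('ss', StandardScaler()),
                    ('knc', KNeighborsClassifier())])

    # Step name, double underscore, then the estimator's parameter name.
    parameters = {'knc__n_neighbors': k_vals}

    # Optional sanity check (illustrative): fail early on mistyped keys.
    unknown = set(parameters) - set(svp.get_params().keys())
    if unknown:
        raise ValueError(f"Unknown pipeline parameters: {sorted(unknown)}")

    return GridSearchCV(estimator=svp, param_grid=parameters, cv=skf)


# Example data and candidate values are assumptions for demonstration.
X, y = load_iris(return_X_y=True)
clf = kNearest([1, 3, 5, 7]).fit(X, y)
print(clf.best_params_)
```

The best parameters come back under the same prefixed key, so the chosen k is read as clf.best_params_['knc__n_neighbors'].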
Upvotes: 2