Reputation: 679
I need to build a custom transformer, use it in a pipeline, and tune and evaluate the parameters of that pipeline with GridSearchCV.
I managed to implement simple custom transformers by following the advice from here, but questions arose when I tried to implement a transformer with an inner estimator and use that construction in GridSearchCV. I cannot find the answer myself, it seems, because I do not fully understand the subtleties of search methods like (Grid/Randomized)SearchCV and of set_params.
The book "Introduction to ML with Python" describes the GridSearchCV logic rather naively:
...iterating over each parameter combination...
init estimator
fit estimator
evaluate
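Roughly, that description corresponds to a loop like the following (my own sketch using LogisticRegression as a stand-in estimator, not code from the book):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ParameterGrid, train_test_split

X, y = make_classification(random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

param_grid = {'C': [0.1, 1.0, 10.0]}
best_score, best_params = None, None
for params in ParameterGrid(param_grid):    # each parameter combination
    est = LogisticRegression(**params)      # init estimator
    est.fit(X_train, y_train)               # fit estimator
    score = est.score(X_valid, y_valid)     # evaluate
    if best_score is None or score > best_score:
        best_score, best_params = score, params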
But this naive description can't answer my question. To clarify my problem, let's look at this case:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler, MinMaxScaler

class OuterTransformer(BaseEstimator, TransformerMixin):
    _options = {'std': StandardScaler(), 'mm': MinMaxScaler()}

    def __init__(self, option='std'):
        ...
The main question for me is: where do I put the logic of choosing the internal estimator? According to the post mentioned above, it should look something like this:
def __init__(self, option='std'):
    self.option = option

def fit(self, data, y=None):
    self.option = self._options[self.option]
    ...
On the other hand, common sense suggests that GridSearchCV must pass the parameters needed to initialize the internal estimator before calling fit, so the internal estimator should be selected in __init__.
It seems that the first way works fine, but I just cannot understand why. Can somebody please explain this behaviour to me?
Upvotes: 2
Views: 1562
Reputation: 679
It looks like I have now understood the logic of initialization and re-initialization of estimator parameters, and this answers my question:
The instance attributes must be initialized with exactly the original values that were passed to the constructor, rather than some "derivatives" of them, because for every re-initialization of the estimator, scikit-learn calls __init__, passing in the parameters that were extracted from the instance by get_params before the CV started.
The essence of get_params is that it scans the signature of the class's __init__ method and pulls from the estimator instance the attributes whose names correspond to the arguments of __init__ (except self, of course).
Thus, if we write "derived" values into those attributes inside __init__, these "derived" values will be carried over into the next re-initialization, and everything will fail.
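As a quick illustration (my own minimal sketch, not part of the original answer), get_params and set_params simply read and write the attributes that match the __init__ arguments:

from sklearn.base import BaseEstimator

class Demo(BaseEstimator):
    def __init__(self, option='std'):
        self.option = option        # store the original value, untouched

d = Demo(option='mm')
print(d.get_params())               # {'option': 'mm'}

d.set_params(option='std')          # this is what GridSearchCV calls between fits
print(d.option)                     # 'std'

GridSearchCV also clones the estimator, and clone(d) effectively calls Demo(**d.get_params()), so whatever get_params returns must be valid constructor input.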
class OuterTransformer(BaseEstimator, TransformerMixin):
    _options = {'std': StandardScaler(), 'mm': MinMaxScaler()}

    # good __init__ - all fine
    def __init__(self, option='std'):
        self.option = option

    # bad __init__ - will not work, because self.option no longer holds
    # the original value of the 'option' parameter
    def __init__(self, option='std'):
        self.option = self._options[option]
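Putting it together, a minimal working sketch (my own, assuming the chosen scaler is resolved inside fit into a separate fitted attribute such as scaler_, which the question does not spell out) looks like this:

from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.preprocessing import StandardScaler, MinMaxScaler

class OuterTransformer(BaseEstimator, TransformerMixin):
    _options = {'std': StandardScaler(), 'mm': MinMaxScaler()}

    def __init__(self, option='std'):
        self.option = option                      # keep only the original string

    def fit(self, X, y=None):
        # resolve the "derived" value here, into a separate attribute,
        # so that self.option always stays re-initializable
        self.scaler_ = clone(self._options[self.option])
        self.scaler_.fit(X)
        return self

    def transform(self, X):
        return self.scaler_.transform(X)

t = OuterTransformer(option='mm')
t2 = clone(t)               # internally ~ OuterTransformer(**t.get_params())
print(t2.get_params())      # {'option': 'mm'} - the original string survives cloning

With the "bad" __init__ instead, get_params() would return the scaler instance itself, and re-initialization via self._options[option] would fail with a KeyError. In a pipeline step named, say, 'outer' (a hypothetical name), a grid like {'outer__option': ['std', 'mm']} then works, because set_params and clone only ever deal with the original string.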
Upvotes: 2