Enrico Gandini

Reputation: 1015

Fix a parameter in a scikit-learn estimator

I need to fix the value of a parameter of a scikit-learn estimator. I still need to be able to change all the other parameters of the estimator, and to use the estimator within scikit-learn tools such as Pipelines and GridSearchCV.

I tried to define a new class inheriting from a scikit-learn estimator. For instance, here I am trying to create a new class that fixes n_estimators=5 of a RandomForestClassifier.

from sklearn.ensemble import RandomForestClassifier


class FiveTreesClassifier(RandomForestClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Overwrite whatever value was passed (or the default) with 5.
        self.n_estimators = 5


fivetrees = FiveTreesClassifier()
randomforest = RandomForestClassifier(n_estimators=5)

# This passes.
assert fivetrees.n_estimators == randomforest.n_estimators
# This fails: get_params() of fivetrees returns an empty dict.
assert fivetrees.get_params() == randomforest.get_params()

The fact that get_params() is not reliable means that I cannot use the new estimator within Pipelines and GridSearchCV (as explained here).
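A minimal sketch of one concrete consequence, assuming the same FiveTreesClassifier as above: sklearn.base.clone (which tools such as GridSearchCV rely on) rebuilds an estimator from get_params(), so any argument that only went through **kwargs is silently dropped in the copy. The choice of max_depth=3 is just an illustrative value.

```python
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier


class FiveTreesClassifier(RandomForestClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.n_estimators = 5


original = FiveTreesClassifier(max_depth=3)

# get_params() only reports the named arguments of __init__;
# a **kwargs catch-all is skipped, so nothing is reported here.
params = original.get_params()

# clone() constructs a new instance from get_params(), so the copy
# is built with no arguments and max_depth falls back to its default.
copy = clone(original)

print(params, original.max_depth, copy.max_depth)
```

The copy still has n_estimators set to 5 because __init__ hard-codes it, but every other setting would have to be passed again by hand.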

I am using scikit-learn 0.24.2, but I think it would actually be the same with newer versions.

I would prefer answers that let me define a new class while fixing the value of a hyperparameter, but I would also accept answers that use other techniques. I would also be thankful for thorough explanations of why I should or should not do this!

Upvotes: 1

Views: 795

Answers (1)

Franco Piccolo

Reputation: 7410

You can use functools.partial:

from functools import partial
from sklearn.ensemble import RandomForestClassifier

NewEstimator = partial(RandomForestClassifier, n_estimators=5)
new_estimator = NewEstimator()
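A short sketch of how this behaves in practice, under the assumption that a plain RandomForestClassifier instance is acceptable: the object returned by the partial is an ordinary RandomForestClassifier, so get_params() reports every parameter. Note that partial only pre-fills the default; it does not lock the value. The max_depth=3 and the override values below are illustrative only.

```python
from functools import partial

from sklearn.ensemble import RandomForestClassifier

NewEstimator = partial(RandomForestClassifier, n_estimators=5)

# Other parameters can still be passed as usual.
new_estimator = NewEstimator(max_depth=3)

# The result is a regular RandomForestClassifier, so get_params()
# matches an estimator built directly with the same arguments.
reference = RandomForestClassifier(n_estimators=5, max_depth=3)
same_params = new_estimator.get_params() == reference.get_params()
is_plain_forest = type(new_estimator) is RandomForestClassifier

# The pre-filled value is a default, not a constraint: it can be
# overridden at construction time or later through set_params().
overridden = NewEstimator(n_estimators=10)
new_estimator.set_params(n_estimators=20)

print(same_params, is_plain_forest, overridden.n_estimators)
```

Because NewEstimator is a callable rather than a class, it cannot be subclassed or used in isinstance checks, and keeping n_estimators at 5 inside GridSearchCV comes down to leaving it out of the parameter grid.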

Upvotes: 1
