Reputation: 9078
I am trying to inherit from BaseEstimator and MetaEstimatorMixin to create a wrapper for a base_estimator, but I am running into problems. I tried to follow the base_ensemble code in the repository, but it didn't help. I get TypeError: get_params() missing 1 required positional argument: 'self' when running the test below, which calls check_estimator(Wrapper). According to the documentation, I don't have to implement get_params if I inherit from BaseEstimator. It seems like something is a class rather than an instance, but I am not able to pin it down.
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, MetaEstimatorMixin, clone
from functools import lru_cache
import numpy as np
from sklearn.linear_model import LogisticRegression
'''
this is a module containing classes which wraps a classifier or a regressor sklearn estimator
'''
class Wrapper(BaseEstimator, MetaEstimatorMixin):
    def __init__(self, base_estimator=LogisticRegression, estimator_params=None):
        super().__init__()
        self.base_estimator = base_estimator
        self.estimator_params = estimator_params

    def fit(self, x, y):
        self.model = self._make_estimator().fit(x,y)

    def _make_estimator(self):
        """Make and configure a copy of the `base_estimator_` attribute.
        Warning: This method should be used to properly instantiate new
        sub-estimators. taken from sklearn github
        """
        estimator = self.base_estimator()
        estimator.set_params(**dict((p, getattr(self, p))
                                    for p in self.estimator_params))
        return estimator

    def predict(self, x):
        self.model.predict(x)
import unittest
from sklearn.utils.estimator_checks import check_estimator
class Test(unittest.TestCase):
    def test_check_estimator(self):
        check_estimator(Wrapper)
Upvotes: 2
Views: 2824
Reputation: 9390
The base_estimator field must be initialized with an object (an instance), not a class:
....
def __init__(self, base_estimator=LogisticRegression(), ...
....
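To make the change concrete, here is a minimal sketch of the wrapper with an instance as the default. Several things in it are my own assumptions rather than part of the question: estimator_params is dropped for brevity, the fitted model is stored as model_, fit clones the base estimator and returns self, predict returns its result, and the toy data comes from make_classification. It only shows that clone and fit work with an instance; it does not claim that the full check_estimator suite passes.

```python
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


class Wrapper(MetaEstimatorMixin, BaseEstimator):
    # The default is an instance, LogisticRegression(), not the class itself.
    def __init__(self, base_estimator=LogisticRegression()):
        # Only store the argument; no validation or copying in __init__.
        self.base_estimator = base_estimator

    def fit(self, X, y):
        # Fit a clone so the estimator passed in is left unmodified
        # (an assumption of this sketch, not something the question does).
        self.model_ = clone(self.base_estimator).fit(X, y)
        return self

    def predict(self, X):
        # Return the predictions; the question's version drops them.
        return self.model_.predict(X)


# Toy data, purely for illustration.
X, y = make_classification(n_samples=50, n_features=4, random_state=0)

wrapper = Wrapper()
copy = clone(wrapper)  # cloning works because every parameter is an instance
predictions = copy.fit(X, y).predict(X)
print(predictions.shape)
```

Cloning inside fit is a design choice here: it keeps the shared default instance unfitted, which matters because default arguments are created once and reused by every Wrapper().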
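The error occurs because clone(safe=False) is used in some of the tests. From the documentation of clone:
safe: boolean, optional. If safe is false, clone will fall back to a deepcopy on objects that are not estimators.
The LogisticRegression class itself has a get_params attribute, so it is treated as an estimator, and get_params then ends up being called on the class without an instance, hence the missing 'self'.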
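The mechanism behind the message can be reproduced without scikit-learn. The sketch below uses a hypothetical stand-in class, TinyEstimator, and a hypothetical helper, collect_params; neither is part of any library. It only illustrates what happens when code that expects an instance calls get_params on a class instead.

```python
class TinyEstimator:
    """Hypothetical stand-in for an estimator; not part of scikit-learn."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def get_params(self):
        return {"alpha": self.alpha}


def collect_params(value):
    """Mimic code that asks whatever it is handed for its parameters."""
    return value.get_params()


# An instance works: `self` is bound automatically.
instance_params = collect_params(TinyEstimator())

# A class does not: accessed on the class, get_params is a plain function,
# so calling it with no arguments leaves `self` unfilled.
try:
    collect_params(TinyEstimator)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(instance_params)
print(error_message)
```

The second call fails with the same kind of "missing 1 required positional argument: 'self'" TypeError as in the question, which is why passing LogisticRegression() instead of LogisticRegression resolves it.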
Upvotes: 3