Hanan Shteingart

Reputation: 9078

sklearn: wrapping an estimator causes a get_params() error about a missing 'self'

I am trying to inherit from BaseEstimator and MetaEstimatorMixin to create a wrapper around a base_estimator, but I am running into problems. I tried to follow the BaseEnsemble code in the repository, but it didn't help. When I run the test below, which calls check_estimator(Wrapper), I get TypeError: get_params() missing 1 required positional argument: 'self'. According to the documentation, I don't have to implement get_params if I inherit from BaseEstimator. It looks like something is a class rather than an instance, but I can't pin down what.
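That error message is what Python raises whenever an instance method is called on a class instead of on an instance, because nothing gets bound to self. A minimal, sklearn-free illustration (`Estimator` is a hypothetical stand-in class, not part of scikit-learn):

```python
class Estimator:
    """Hypothetical stand-in for an sklearn estimator."""

    def __init__(self, C=1.0):
        self.C = C

    def get_params(self, deep=True):
        return {"C": self.C}


# Called on an instance: Python binds the instance to `self` automatically.
print(Estimator().get_params())  # {'C': 1.0}

# Called on the class: nothing is bound to `self`, so Python raises TypeError.
try:
    Estimator.get_params()
except TypeError as exc:
    print("TypeError:", exc)
```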

'''
this is a module containing classes which wrap a classifier or a regressor sklearn estimator
'''
import unittest
from functools import lru_cache

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, MetaEstimatorMixin, clone
from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_estimator


class Wrapper(BaseEstimator, MetaEstimatorMixin):
    # note: base_estimator defaults to the class LogisticRegression, not an instance
    def __init__(self, base_estimator=LogisticRegression, estimator_params=None):
        super().__init__()
        self.base_estimator = base_estimator
        self.estimator_params = estimator_params

    def fit(self, x, y):
        self.model = self._make_estimator().fit(x, y)
        return self  # sklearn estimators return self from fit

    def _make_estimator(self):
        """Make and configure a copy of the `base_estimator` attribute.
        Warning: This method should be used to properly instantiate new
        sub-estimators. taken from sklearn github
        """
        estimator = self.base_estimator()
        # estimator_params names attributes of self to forward; None means none
        estimator.set_params(**dict((p, getattr(self, p))
                                    for p in (self.estimator_params or ())))
        return estimator

    def predict(self, x):
        return self.model.predict(x)


class Test(unittest.TestCase):
    def test_check_estimator(self):
        # older scikit-learn accepts a class here; newer releases expect an instance, e.g. Wrapper()
        check_estimator(Wrapper)

Upvotes: 2

Views: 2824

Answers (1)

Ibraim Ganiev

Reputation: 9390

The base_estimator parameter must be initialized with an estimator object (an instance), not with the class:

....
def __init__(self, base_estimator=LogisticRegression(), ...
....
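With an instance as the default, _make_estimator can no longer call self.base_estimator(); the instance has to be copied instead. Below is a minimal sketch along those lines, assuming estimator_params keeps its BaseEnsemble meaning (a tuple of attribute names on the wrapper that are forwarded to the sub-estimator). It uses sklearn.base.clone to get a fresh unfitted copy, puts the mixin to the left of BaseEstimator as scikit-learn's developer documentation recommends, and stores the fitted model as model_ following the trailing-underscore convention for fitted attributes. It is not guaranteed to pass every check_estimator check.

```python
import numpy as np
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
from sklearn.linear_model import LogisticRegression


class Wrapper(MetaEstimatorMixin, BaseEstimator):
    def __init__(self, base_estimator=LogisticRegression(), estimator_params=()):
        # Store constructor arguments unchanged so get_params/clone can see them.
        self.base_estimator = base_estimator
        self.estimator_params = estimator_params

    def _make_estimator(self):
        # clone() returns an unfitted copy with the same parameters, so the
        # default instance shared between Wrapper objects is never fitted.
        estimator = clone(self.base_estimator)
        # Forward the named wrapper attributes (BaseEnsemble-style), if any.
        estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})
        return estimator

    def fit(self, X, y):
        self.model_ = self._make_estimator().fit(X, y)
        return self

    def predict(self, X):
        return self.model_.predict(X)


X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])
print(Wrapper().fit(X, y).predict(X))
print(clone(Wrapper()).get_params()["base_estimator"])
```

scikit-learn's own meta-estimators (for example BaggingClassifier) usually default the sub-estimator parameter to None and pick a concrete estimator inside fit; the instance default here simply follows the answer.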

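Your error happens because some of the checks clone the estimator, and clone calls clone(param, safe=False) on every constructor parameter. The LogisticRegression class has a get_params attribute, so clone treats it as an estimator and calls get_params on the class itself, where there is no instance to bind to self. From the clone docstring: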

safe: boolean, optional
    If safe is false, clone will fall back to a deepcopy on objects
    that are not estimators.
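To see why passing the class breaks, here is a simplified, sklearn-free sketch of the branching in scikit-learn's clone at the time of the question. The names simplified_clone, Leaf and Meta are hypothetical, and the real function does more (for example, sanity-checking the cloned parameters). The key point is that anything with a get_params attribute is treated as an estimator instance, and a class has that attribute too.

```python
import copy


def simplified_clone(estimator, safe=True):
    """Hypothetical, stripped-down version of the logic in sklearn.base.clone."""
    if isinstance(estimator, (list, tuple, set, frozenset)):
        return type(estimator)(simplified_clone(e, safe=safe) for e in estimator)
    if not hasattr(estimator, "get_params"):
        if not safe:
            # Non-estimators (numbers, strings, ...) are simply deep-copied.
            return copy.deepcopy(estimator)
        raise TypeError("Cannot clone object %r" % (estimator,))
    # A *class* also has a get_params attribute, so it lands here, and calling
    # get_params on it fails because nothing is bound to `self`.
    params = estimator.get_params(deep=False)
    new_params = {k: simplified_clone(v, safe=False) for k, v in params.items()}
    return type(estimator)(**new_params)


class Leaf:
    def __init__(self, C=1.0):
        self.C = C

    def get_params(self, deep=True):
        return {"C": self.C}


class Meta:
    def __init__(self, base_estimator=Leaf):  # class default, as in the question
        self.base_estimator = base_estimator

    def get_params(self, deep=True):
        return {"base_estimator": self.base_estimator}


cloned = simplified_clone(Meta(base_estimator=Leaf(C=2.0)))
print(type(cloned.base_estimator).__name__, cloned.base_estimator.C)  # Leaf 2.0

try:
    simplified_clone(Meta())  # default is the class Leaf, not an instance
except TypeError as exc:
    print("TypeError:", exc)
```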

Upvotes: 3
