Gere

Reputation: 12697

How to implement a meta-estimator with the scikit-learn API?

I would like to implement a simple wrapper / meta-estimator which is compatible with all of scikit-learn. It is hard to find a full description of what exactly I need.

The goal is to have a regressor which also learns a threshold to become a classifier. So I came up with:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        # threshold_ does not get initialized in __init__ ??

    def fit(self, X, y, optimal_threshold):
        self.regressor = clone(self.regressor)    # is this required by sklearn??
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict(X)
        self.threshold_ = optimal_threshold(y_raw)

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y

Does this implement the full API I need?

My main question is where to put the threshold. I want it to be learned only once and then re-used in subsequent .fit calls with new data without being readjusted. But with the current version it has to be retuned on every .fit call, which I do not want.

On the other hand, if I make it a fixed parameter self.threshold and pass it to __init__, then I'm not supposed to change it based on the data.

How can I make a threshold parameter which can be tuned in one call of .fit and be fixed for subsequent .fit calls?

Upvotes: 9

Views: 1266

Answers (2)

Adithya

Reputation: 63

I actually wrote a blog post about this the other day. I assume you are trying to build something similar to TransformedTargetRegressor; I would suggest taking a look at its source code as a starting point.
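
For reference, here is a minimal sketch of how TransformedTargetRegressor is used as a wrapper around another regressor (the log/exp transform is just an arbitrary example, not something your Thresholder would need):

import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.linear_model import LinearRegression

# Wraps an inner regressor: the target is transformed before fitting
# and the inverse transform is applied to the raw predictions.
wrapped = TransformedTargetRegressor(
    regressor=LinearRegression(),
    func=np.log1p,          # applied to y before fit
    inverse_func=np.expm1,  # applied to predictions
)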

Your current implementation seems about right. As far as this concern goes:

How can I make a threshold parameter which can be tuned in one call of .fit and be fixed for subsequent .fit calls?

I would suggest against that, because scikit-learn's API is based around the fit method re-fitting all tunable aspects of the model. There are two routes you can go here: either add a **kwarg to fit that explicitly protects the threshold from updating, or go with what @rotem-tal suggested. If you choose the latter, it might look something like this:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin

def optimal_threshold(y_raw: np.ndarray) -> np.ndarray:
    return np.array([0.1, 0.5, 1])  # some implementation here

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        self.threshold = None

    def fit(self, X, y, optimal_threshold):
        # you don't need to clone the regressor
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict(X)
        if self.threshold is None:
            self.threshold = optimal_threshold(y_raw)
        return self  # fit conventionally returns self

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, np.atleast_1d(self.threshold))

        return y
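
For the first route (a keyword in fit that protects the threshold), a rough sketch is below; the refit_threshold keyword and the my_search tuning function are purely illustrative names, not part of any scikit-learn API:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor

    def fit(self, X, y, optimal_threshold=None, refit_threshold=True):
        self.regressor_ = clone(self.regressor)
        self.regressor_.fit(X, y)

        y_raw = self.regressor_.predict(X)
        # recompute the threshold only when asked to and a tuning
        # function was supplied; otherwise keep the previous value
        if refit_threshold and optimal_threshold is not None:
            self.threshold_ = optimal_threshold(y_raw)
        return self

    def predict(self, X):
        y_raw = self.regressor_.predict(X)
        return np.digitize(y_raw, np.atleast_1d(self.threshold_))

A first call like clf.fit(X1, y1, optimal_threshold=my_search) tunes the threshold, while later calls with refit_threshold=False (or simply without optimal_threshold) only refit the inner regressor.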

Upvotes: 1

Anatoly Alekseev

Reputation: 2410

A totally legit question. Yes, to ensure compatibility, you must

  1. not do anything except persisting the params in __init__
  2. clone the inner estimator in fit. Just use the underscore: self.regressor_ = clone(self.regressor)
  3. to allow for more flexibility, it's probably better to have

    def fit(self, X, y, **fit_params):
        # pop so the threshold argument is not forwarded to the inner regressor
        optimal_threshold = fit_params.pop('optimal_threshold', 0.5)
        self.regressor_.fit(X, y, **fit_params)

    instead of just fit(self, X, y, optimal_threshold)

and for (potentially) better performance

    y_raw = self.regressor_.fit_predict(X, y, **fit_params)

  4. add to fit

    # Check that X and y have correct shape
    # (check_X_y lives in sklearn.utils.validation)
    X, y = check_X_y(X, y)
    
  5. add to predict

    # Check that fit has been called
    check_is_fitted(self)

    # Input validation
    X = check_array(X)

  6. ensure that conformity checks pass, say:

    from sklearn.utils.estimator_checks import check_estimator
    from sklearn.linear_model import LinearRegression

    check_estimator(Thresholder(regressor=LinearRegression()))

And of course, read the dev guide, especially if you need to set random_state.
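
On that last point, the dev guide's convention is to store random_state untouched in __init__ and only turn it into a usable RandomState inside fit; a minimal sketch, assuming the wrapper itself needs randomness somewhere:

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import check_random_state

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor, random_state=None):
        # store params exactly as given; no validation or conversion here
        self.regressor = regressor
        self.random_state = random_state

    def fit(self, X, y, **fit_params):
        # convert only at fit time, and draw any randomness from rng
        # rather than from the global np.random state
        rng = check_random_state(self.random_state)
        ...
        return self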

Upvotes: 0
