Valentin Calomme

In scikit-learn, check_estimator for ClassifierMixin

I am trying to create a scikit-learn compliant classifier by extending BaseEstimator and ClassifierMixin. I have read the official documentation and also tried to follow some online guides, like this one.

I can create estimators that pass the check_estimator() test. However, whenever I try to create a classifier, it never passes the test. Even the template that Scikit-Learn provides doesn't pass the test...

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import euclidean_distances
from sklearn.utils.validation import check_X_y, check_is_fitted, check_array

class MyCustomClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, param1=2):
        self.param1 = param1

    def fit(self, X, y=None, **kwargs):
        # Check that X and y have correct shape
        X, y = check_X_y(X, y)

        # Store the classes seen during fit
        self.classes_ = np.unique(y)

        self.X_ = X
        self.y_ = y

        return self


    def predict(self, X):
        # Check that fit has been called
        check_is_fitted(self, ['X_', 'y_', 'classes_'])

        # Input validation
        X = check_array(X)

        closest = np.argmin(euclidean_distances(X, self.X_), axis=1)

        return self.y_[closest]

MyCustomClassifier()

from sklearn.utils.estimator_checks import check_estimator

check_estimator(MyCustomClassifier)

It seems that I am missing some validation step that should raise an error, because this is the traceback I get:

Traceback (most recent call last):
  File "C:/Users/vca/Google Drive/Internship/Skratch/supervised/logistic_regression.py", line 97, in <module>
    check_estimator(MyCustomClassifier)
  File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\estimator_checks.py", line 265, in check_estimator
    check(name, estimator)
  File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\testing.py", line 291, in wrapper
    return fn(*args, **kwargs)
  File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\estimator_checks.py", line 1729, in check_classifiers_regression_target
    assert_raises_regex(ValueError, msg, e.fit, X, y)
  File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 1258, in assertRaisesRegex
    return context.handle('assertRaisesRegex', args, kwargs)
  File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 176, in handle
    callable_obj(*args, **kwargs)
  File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 196, in __exit__
    self.obj_name))
  File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 134, in _raiseFailure
    raise self.test_case.failureException(msg)
AssertionError: ValueError not raised by fit

Has anybody successfully created a classifier that passes the test?

Answers (1)

Valentin Calomme

I just found out how to fix this: call check_classification_targets(y) inside fit. It raises a ValueError when y holds continuous (regression) targets, which is exactly what the failing check, check_classifiers_regression_target, expects fit to do.
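As a minimal sketch of what that function does in isolation (the sample arrays are made up for illustration; type_of_target is the scikit-learn helper that labels a target array as binary, multiclass, continuous, and so on):

```python
import numpy as np
from sklearn.utils.multiclass import check_classification_targets, type_of_target

y_labels = np.array([0, 1, 1, 2])              # discrete class labels
y_continuous = np.array([0.1, 1.7, 2.5, 3.2])  # regression-style targets

print(type_of_target(y_labels))      # multiclass
print(type_of_target(y_continuous))  # continuous

# Valid classification targets pass silently.
check_classification_targets(y_labels)

# Continuous targets are rejected with a ValueError.
try:
    check_classification_targets(y_continuous)
    raised = False
except ValueError as exc:
    raised = True
    print(exc)
```

The full classifier with this call added to fit: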

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import euclidean_distances
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y, check_is_fitted, check_array

class MyCustomClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, param1=2):
        self.param1 = param1

    def fit(self, X, y=None, **kwargs):
        # Check that X and y have correct shape
        X, y = check_X_y(X, y)
        check_classification_targets(y)

        # Store the classes seen during fit
        self.classes_ = np.unique(y)

        self.X_ = X
        self.y_ = y

        return self

    def predict(self, X):
        # Check that fit has been called
        check_is_fitted(self, ['X_', 'y_', 'classes_'])

        # Input validation
        X = check_array(X)

        closest = np.argmin(euclidean_distances(X, self.X_), axis=1)

        return self.y_[closest]
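To confirm the fix without running the whole test suite, the failing check can be approximated by hand. This is only a rough sketch, assuming the check boils down to fitting on continuous targets and expecting a ValueError. NearestSampleClassifier is a hypothetical compact copy of the class above, and the random data is made up for illustration.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import euclidean_distances
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted


# Hypothetical compact copy of the classifier, for illustration only.
# The mixin is listed first, the order recent scikit-learn docs recommend.
class NearestSampleClassifier(ClassifierMixin, BaseEstimator):
    def fit(self, X, y):
        X, y = check_X_y(X, y)
        check_classification_targets(y)  # rejects continuous targets
        self.classes_ = np.unique(y)
        self.X_ = X
        self.y_ = y
        return self

    def predict(self, X):
        check_is_fitted(self, ["X_", "y_", "classes_"])
        X = check_array(X)
        # Return the label of the closest training sample.
        closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
        return self.y_[closest]


rng = np.random.RandomState(0)
X = rng.uniform(size=(20, 3))
y_continuous = rng.uniform(size=20)     # regression-style targets
y_labels = (X[:, 0] > 0.5).astype(int)  # valid class labels

# Roughly what the failing check asserts: fit must raise ValueError here.
try:
    NearestSampleClassifier().fit(X, y_continuous)
    rejected = False
except ValueError:
    rejected = True
print("continuous targets rejected:", rejected)

# Proper class labels still fit and predict as before.
clf = NearestSampleClassifier().fit(X, y_labels)
train_predictions = clf.predict(X)
```

If the check_classification_targets line is removed from fit, the first fit call succeeds on the continuous targets and rejected stays False, which matches the "ValueError not raised by fit" message in the question.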
