Reputation: 618
I am trying to create a "Scikit-Learn compliant" classifier by extending BaseEstimator and ClassifierMixin. I have read the documentation on their website and I also tried to follow some guides online like this one
I can create estimators that pass the check_estimator()
test. However, whenever I try to create a classifier, it never passes the test. Even the template that Scikit-Learn provides doesn't pass the test...
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_is_fitted, check_array
from sklearn.utils.estimator_checks import check_classifiers_classes
from sklearn.metrics import euclidean_distances
import numpy as np
class MyCustomClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, param1=2):
self.param1 = param1
def fit(self, X, y=None, **kwargs):
# Check that X and y have correct shape
X, y = check_X_y(X, y)
# Store the classes seen during fit
self.classes_ = np.unique(y)
self.X_ = X
self.y_ = y
return self
def predict(self, X):
# Check is fit had been called
check_is_fitted(self, ['X_', 'y_', 'classes_'])
# Input validation
X = check_array(X)
closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
return self.y_[closest]
MyCustomClassifier()
from sklearn.utils.estimator_checks import check_estimator
check_estimator(MyCustomClassifier)
It seems that I am forgetting some type of test that would raise an error because this is the error I get:
Traceback (most recent call last):
File "C:/Users/vca/Google Drive/Internship/Skratch/supervised/logistic_regression.py", line 97, in <module>
check_estimator(MyCustomClassifier)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\estimator_checks.py", line 265, in check_estimator
check(name, estimator)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\testing.py", line 291, in wrapper
return fn(*args, **kwargs)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\site-packages\sklearn\utils\estimator_checks.py", line 1729, in check_classifiers_regression_target
assert_raises_regex(ValueError, msg, e.fit, X, y)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 1258, in assertRaisesRegex
return context.handle('assertRaisesRegex', args, kwargs)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 176, in handle
callable_obj(*args, **kwargs)
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 196, in __exit__
self.obj_name))
File "C:\Users\vca\AppData\Local\Continuum\anaconda3\envs\skratch\lib\unittest\case.py", line 134, in _raiseFailure
raise self.test_case.failureException(msg)
AssertionError: ValueError not raised by fit
Anybody who successfully created a classifier that passes the test?
Upvotes: 2
Views: 2519
Reputation: 618
I just found out how to fix this. One needs to run check_classification_targets
in the fit. This apparently raises an error if regression targets are used.
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import euclidean_distances
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y, check_is_fitted, check_array
class MyCustomClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, param1=2):
self.param1 = param1
def fit(self, X, y=None, **kwargs):
# Check that X and y have correct shape
X, y = check_X_y(X, y)
check_classification_targets(y)
# Store the classes seen during fit
self.classes_ = np.unique(y)
self.X_ = X
self.y_ = y
return self
def predict(self, X):
# Check is fit had been called
check_is_fitted(self, ['X_', 'y_', 'classes_'])
# Input validation
X = check_array(X)
closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
return self.y_[closest]
Upvotes: 3