Diogo Silva

Reputation: 330

Plot Learning Curve of CatBoostClassifier with Yellowbrick

I'm trying to plot a learning curve for a CatBoostClassifier. The error occurs when I pass the CatBoostClassifier to Yellowbrick's LearningCurve visualizer and call fit. I expected this to work, since CatBoost is scikit-learn compatible and Yellowbrick is a scikit-learn extension.

Code snippet:

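import numpy as np
from catboost import CatBoostClassifier
from sklearn.model_selection import RepeatedStratifiedKFold
from yellowbrick.model_selection import LearningCurve

kf = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=0)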
sizes = np.linspace(0.2, 1.0, 10)
estimator = CatBoostClassifier(
    iterations=42, learning_rate=0.3, max_depth=10)

visualizer = LearningCurve(
    estimator, cv=kf, scoring='accuracy', train_sizes=sizes, n_jobs=-1
)

visualizer.fit(X, y)
visualizer.show()

Error:

... yellowbrick.exceptions.YellowbrickTypeError: Cannot detect the model name for non estimator: ''

Any suggestions?

Upvotes: 1

Views: 2646

Answers (2)

Satrio Adi Prabowo

Reputation: 600

You can use Yellowbrick's wrapper for third-party estimators (the Yellowbrick documentation has more details). I've tried it and it worked. Something like this:

from catboost import CatBoostClassifier
from yellowbrick.classifier import ROCAUC
from yellowbrick.contrib.wrapper import wrap

catboost_model = CatBoostClassifier()
model = wrap(catboost_model)
visualizer = ROCAUC(model)
visualizer.fit(X_train, y_train)
visualizer.score(X_test, y_test)
visualizer.show()
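Why the wrapper helps, presumably: Yellowbrick appears to decide whether an object is an estimator with an isinstance check against scikit-learn's BaseEstimator, and CatBoostClassifier does not inherit from that class even though it offers fit/predict. That would also explain why XGBClassifier, whose scikit-learn wrapper does subclass BaseEstimator, works without wrapping. The dependency-free sketch below illustrates that adapter idea only; BaseEstimator, is_estimator, get_model_name, ThirdPartyClassifier, Wrapped and wrap here are all hypothetical stand-ins, not the real Yellowbrick or scikit-learn code.

```python
# All names below are hypothetical stand-ins, not Yellowbrick/scikit-learn code.


class BaseEstimator:
    """Stand-in for sklearn.base.BaseEstimator."""


def is_estimator(model):
    # Nominal type check: only BaseEstimator subclasses pass,
    # regardless of which methods the object implements.
    return isinstance(model, BaseEstimator)


def get_model_name(model):
    # Same shape of check as the one behind the error in the question.
    if not is_estimator(model):
        raise TypeError(
            "Cannot detect the model name for non estimator: '{}'".format(type(model))
        )
    return model.__class__.__name__


class ThirdPartyClassifier:
    """Has fit/predict like a scikit-learn model, but no BaseEstimator parent."""

    def fit(self, X, y):
        self.classes_ = sorted(set(y))
        return self

    def predict(self, X):
        # Trivial model: always predicts the first class seen in fit().
        return [self.classes_[0] for _ in X]


class Wrapped(BaseEstimator):
    """Adapter that passes the type check and forwards calls to the real model."""

    def __init__(self, estimator):
        self._estimator = estimator

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup fails, e.g. for fit/predict.
        return getattr(self._estimator, attr)


def wrap(estimator):
    return Wrapped(estimator)


clf = ThirdPartyClassifier()
try:
    get_model_name(clf)
except TypeError as err:
    print(err)  # the unwrapped model is rejected

model = wrap(clf)
print(get_model_name(model))  # Wrapped
model.fit([[0], [1]], ["a", "b"])
print(model.predict([[2]]))  # ['a']
```

The point of the design is that the adapter changes only the object's type, not its behaviour: every method the visualizer calls is still answered by the wrapped model.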

Upvotes: 3

user14044591

Reputation: 1

I was able to plot a learning curve for XGBClassifier, so I think it should also work for CatBoostClassifier. However, visualizer.show() won't work; use visualizer.poof() to render the plot.
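Whether show() or poof() exists depends on the installed Yellowbrick version: as far as I know, show() was introduced in Yellowbrick 1.0 and poof() is the older, later-deprecated name. A small helper that works with either, assuming only those two method names (render, OldVisualizer and NewVisualizer are hypothetical, for illustration):

```python
def render(visualizer):
    """Call show() if this visualizer has it, otherwise fall back to poof()."""
    show = getattr(visualizer, "show", None)
    if callable(show):
        return show()
    return visualizer.poof()


class OldVisualizer:
    """Stand-in for a pre-1.0 visualizer that only has poof()."""

    def poof(self):
        return "rendered with poof()"


class NewVisualizer:
    """Stand-in for a 1.0+ visualizer that has show()."""

    def show(self):
        return "rendered with show()"


print(render(OldVisualizer()))  # rendered with poof()
print(render(NewVisualizer()))  # rendered with show()
```

Preferring show() when it exists means the helper keeps working after upgrading, when poof() is gone or emits a deprecation warning.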

Upvotes: -1
