TracNav

Reputation: 45

Python sklearn: why must I set up a new estimator to plot a learning curve?

I'm using GridSearchCV to tune an SVM classifier, then plotting a learning curve. However, unless I set up a fresh classifier before plotting the learning curve, I run into an IndexError, and I'm not quite sure why.

My CV / classifier set up is below:

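# Legacy scikit-learn (< 0.20) API, as used in the question
from sklearn import cross_validation
from sklearn.grid_search import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

# Set up classifier
clf_untuned = OneVsRestClassifier(SVC(kernel='rbf', random_state=0, max_iter=1000))
cv = cross_validation.ShuffleSplit(data_image.shape[1], n_iter=10,
                                   test_size=0.1, random_state=0)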

# Use cross validation / grid search to find optimal hyperparameters
if TRAINING_CROSS_VALIDATION == 1:
    params = {
        ...
    }
    clf_tuned = GridSearchCV(clf_untuned, cv=cv, param_grid=params)
    clf_tuned.fit(x_train, y_train)
    print('Best parameters: %s' % clf_tuned.best_params_)
else:
    clf_tuned = OneVsRestClassifier(SVC(kernel='rbf',
                                        C=100, gamma=0.00001, random_state=0, verbose=0))
    clf_tuned.fit(x_train, y_train)
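The cross_validation and grid_search modules used above were deprecated in scikit-learn 0.18 and removed in 0.20. Below is a minimal sketch of the same setup with sklearn.model_selection. It assumes synthetic data from make_classification in place of the real x_train/y_train and a hypothetical parameter grid, since the original grid is elided. Note that the newer ShuffleSplit no longer takes a sample count: its indices are generated from whatever array is passed to split().

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, ShuffleSplit, train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

# Synthetic stand-in for the question's data (assumption: 3 classes)
X, y = make_classification(n_samples=300, n_features=20, n_informative=5,
                           n_classes=3, random_state=0)
x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf_untuned = OneVsRestClassifier(SVC(kernel='rbf', random_state=0, max_iter=1000))
# model_selection.ShuffleSplit takes no sample count; indices come from split(X)
cv = ShuffleSplit(n_splits=10, test_size=0.1, random_state=0)

# Hypothetical grid: OneVsRestClassifier exposes the wrapped SVC as "estimator"
params = {'estimator__C': [1, 10, 100], 'estimator__gamma': [0.001, 0.01]}
clf_tuned = GridSearchCV(clf_untuned, cv=cv, param_grid=params)
clf_tuned.fit(x_train, y_train)
print('Best parameters: %s' % clf_tuned.best_params_)

# The generated indices always fit the array they are generated for
small = x_train[:70]
max_index = max(train.max() for train, _ in cv.split(small))
```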

I then plot the learning curve, where plot_learning_curve is copied from the sklearn example (http://scikit-learn.org/stable/auto_examples/model_selection/plot_learning_curve.html). With the following code, I get the error below at the learning_curve call inside plot_learning_curve:

# Plot learning curve for best params -- yields IndexError
plot_learning_curve(clf_tuned, title, x_train, y_train, ylim=(0.6, 1.05), cv=cv)

IndexError: index 663 is out of bounds for size 70

However, if I instead create a new classifier, everything works:

# Plot learning curve for best params -- functions correctly
estimator = OneVsRestClassifier(SVC(kernel='rbf',
                                    C=100, gamma=0.00001, random_state=0, verbose=0))
plot_learning_curve(estimator, title, x_train, y_train, ylim=(0.6, 1.05), cv=cv)

Why is this? Many thanks in advance, and other comments on my questionable implementation are welcome.

Upvotes: 1

Views: 1618

Answers (1)

Prateek

Reputation: 429

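The problem was resolved by passing the best estimator found by the grid search, clf_tuned.best_estimator_, to plot_learning_curve instead of clf_tuned itself. clf_tuned is a GridSearchCV object, so learning_curve refits the entire grid search, including its inner cv, on every training subset. The old cross_validation.ShuffleSplit is built for a fixed number of samples (data_image.shape[1] here), so its indices can exceed the size of the smaller training subsets, which is most likely the source of the IndexError. best_estimator_ is the OneVsRestClassifier refit with the best parameters, so it has no inner cross-validation.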
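A minimal sketch of the fix using the current sklearn.model_selection API. It assumes synthetic data from make_classification and a hypothetical parameter grid (neither is from the question), and it calls learning_curve directly rather than the example's plot_learning_curve helper, so nothing is plotted.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, ShuffleSplit, learning_curve
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

# Synthetic stand-in data (assumption: 3 classes, 300 samples)
X, y = make_classification(n_samples=300, n_features=20, n_informative=5,
                           n_classes=3, random_state=0)
cv = ShuffleSplit(n_splits=10, test_size=0.1, random_state=0)

# Hypothetical grid over the wrapped SVC's parameters
clf_tuned = GridSearchCV(
    OneVsRestClassifier(SVC(kernel='rbf', random_state=0)),
    param_grid={'estimator__C': [1, 100], 'estimator__gamma': [0.001, 0.01]},
    cv=cv,
)
clf_tuned.fit(X, y)

# best_estimator_ is a plain OneVsRestClassifier refit with the best params,
# so learning_curve fits only that model, not the whole grid search
best = clf_tuned.best_estimator_
train_sizes, train_scores, test_scores = learning_curve(
    best, X, y, train_sizes=np.linspace(0.2, 1.0, 5), cv=cv)
print(train_sizes)
print(test_scores.mean(axis=1))
```

Passing clf_tuned itself would also run with this newer ShuffleSplit, but it would repeat the full grid search for every training size and split, which is rarely what a learning curve for the tuned model is meant to show.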

Upvotes: 1
