ste

Reputation: 458

Tabulate accuracy and mean for each fold in GridSearchCV from scikit-learn

I'm doing a grid search over my model with scikit-learn in Python 3, using two parameters, A and B. The code looks like this:

parameterA = ['a', 'b']
parameterB = np.array([10, 100])
param_grid = dict(parameterA=parameterA, parameterB=parameterB)
model = buildModel()
grid = GridSearchCV(model, param_grid, scoring="accuracy")
grid_result = grid.fit(X, Y)
for parameters, mean_score, scores in grid_result.grid_scores_:
    print("Mean: " + str(scores.mean()))
    print("Parameters: " + str(parameters))

Upvotes: 3

Views: 4636

Answers (1)

MMF

Reputation: 5929

First, you should no longer use grid_scores_: it was deprecated in version 0.18 in favor of the cv_results_ attribute, and it was removed in version 0.20.


Q : Did I understand correctly that scores.mean() is the mean of the accuracies?

A : The cv_results_ attribute is a dictionary that holds all the metrics you are looking for, for example mean_test_score and std_test_score for each parameter combination. See cv_results_ in the GridSearchCV documentation.
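As a rough illustration, the loop from the question could be rewritten against cv_results_ along these lines. This is only a sketch: the iris data, the SVC estimator and the kernel/C grid are assumed stand-ins for the question's X, Y, buildModel() and parameterA/parameterB, and cv=3 is set explicitly to match the three folds mentioned in the question.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Assumed stand-ins for the question's X, Y and buildModel()
X, Y = load_iris(return_X_y=True)
param_grid = {"kernel": ["linear", "rbf"], "C": [10, 100]}

grid = GridSearchCV(SVC(), param_grid, scoring="accuracy", cv=3)
grid_result = grid.fit(X, Y)  # fit() returns the fitted GridSearchCV itself

# cv_results_ is a dict of arrays, one entry per parameter combination
results = grid_result.cv_results_
for params, mean, std in zip(results["params"],
                             results["mean_test_score"],
                             results["std_test_score"]):
    print("Mean: %.3f (std %.3f)  Parameters: %s" % (mean, std, params))
```

Here mean_test_score is the accuracy averaged over the cross-validation folds for each parameter combination.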


Q : Is it possible to get all those values for each fold of the cross validation? By default, there are k=3 folds, so I'd expect three times a mean and an accuracy for each parameter combination.

A : Yes, you'll need to use the verbose parameter. It must be an integer and controls the verbosity: the higher the value, the more messages are printed. For instance, you could set verbose=3.
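Note that verbose only prints the per-fold scores while fitting. If you want them in a table, cv_results_ also stores them under the keys split0_test_score, split1_test_score and so on. A minimal sketch, again assuming iris data, an SVC estimator and a kernel/C grid in place of the question's own model and parameters:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Assumed stand-ins for the question's data, model and parameters
X, Y = load_iris(return_X_y=True)
param_grid = {"kernel": ["linear", "rbf"], "C": [10, 100]}
n_folds = 3

grid = GridSearchCV(SVC(), param_grid, scoring="accuracy", cv=n_folds)
grid.fit(X, Y)
results = grid.cv_results_

# One row per parameter combination, one column per fold
fold_scores = np.column_stack(
    [results["split%d_test_score" % i] for i in range(n_folds)]
)
for params, scores in zip(results["params"], fold_scores):
    print(params, "folds:", np.round(scores, 3), "mean:", round(scores.mean(), 3))
```

In recent scikit-learn versions, the row means computed this way should match mean_test_score.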


Q : How can I put in my own scoring function?

A : Define a scoring function and wrap it with make_scorer. The function must have the following signature: score_func(y, y_pred, **kwargs). A basic one could be the ratio of correctly classified samples to the total number of samples, which is plain accuracy (you can use any metric that gives you a good idea of how your classifier performs). With greater_is_better=True, higher values mean a better model; for a true loss, where lower is better, pass greater_is_better=False instead.

You would do it like this:

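import numpy as np
from sklearn.metrics import make_scorer

def my_loss_func(y, y_pred):
    # fraction of correctly classified samples
    return np.sum(y == y_pred) / float(len(y_pred))

my_scorer = make_scorer(my_loss_func, greater_is_better=True)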

You can then pass your scorer to GridSearchCV through its scoring parameter.
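For example, a self-contained sketch of plugging the scorer into the search might look like this. The dataset, estimator and grid are assumptions chosen only to make the snippet runnable; substitute your own model and parameters.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

def my_loss_func(y, y_pred):
    # fraction of correctly classified samples
    return np.sum(y == y_pred) / float(len(y_pred))

my_scorer = make_scorer(my_loss_func, greater_is_better=True)

# Assumed stand-ins for the question's data, model and parameters
X, Y = load_iris(return_X_y=True)
param_grid = {"kernel": ["linear", "rbf"], "C": [10, 100]}

# The scorer object goes where the string "accuracy" was used before
grid = GridSearchCV(SVC(), param_grid, scoring=my_scorer, cv=3)
grid.fit(X, Y)
print(grid.best_params_, grid.best_score_)
```

Because this particular function computes accuracy, the results should agree with scoring="accuracy"; the value of make_scorer is that you can swap in any metric with the same signature.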

Upvotes: 4
