Murilo

Reputation: 679

Accuracy and Confusion Matrix in Cross Validation

I am training a model to solve a binary classification problem using scikit-learn, and I wish to perform cross-validation with 5 folds.

As metrics, I would like to get both the average accuracy and a confusion matrix over the 5 folds.

Using cross_validate, I can pass multiple metrics to the scoring parameter.

According to this link, I can define a function that returns the confusion matrix at each fold. That piece of code predicts the output with .predict(X). But shouldn't a test set, x_test, have been used instead? Since cross_validate produces a different test set at each fold, I don't understand how we can just pass X to both confusion_matrix_scorer() and .predict(). Another question: clf is the svm here, right?

Upvotes: 1

Views: 1026

Answers (1)

Miguel Trejo

Reputation: 6667

The docs state that a callable scorer should satisfy the following:

It can be called with parameters (estimator, X, y), where estimator is the model that should be evaluated, X is validation data, and y is the ground truth target for X (in the supervised case) or None (in the unsupervised case).

When cross_validate is called, the cv folds are generated first and passed to independent fitting processes. Inside each of these processes, the test dataset for that fold is passed to a private _score method. From the source code:

test_scores = _score(estimator, X_test, y_test, scorer, error_score)

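which calls the input scorer with the parameters (estimator, X, y). So the X your scorer receives is that fold's X_test, and clf is the estimator you passed to cross_validate, fitted on that fold's training data. From the source code: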

scores = scorer(estimator, X_test, y_test)
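To check this behaviour yourself, here is a minimal sketch with a scorer that records how many rows it receives. The synthetic dataset, the SVC estimator, and the names size_recording_scorer and seen_sizes are assumptions made for illustration, not part of the linked example.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, cross_validate
from sklearn.svm import SVC

# Hypothetical toy data: 100 samples, binary target.
X, y = make_classification(n_samples=100, random_state=0)

seen_sizes = []  # number of rows the scorer receives on each call

def size_recording_scorer(estimator, X_fold, y_fold):
    # X_fold is whatever cross_validate hands to the scorer.
    seen_sizes.append(len(X_fold))
    return estimator.score(X_fold, y_fold)

# Plain KFold so that every test fold has exactly 100 / 5 = 20 rows.
cross_validate(
    SVC(kernel="linear"),
    X,
    y,
    cv=KFold(n_splits=5),
    scoring=size_recording_scorer,
)

print(seen_sizes)
```

Each recorded size is 20 rather than 100, so the argument named X inside the scorer is the held-out fold and not the full dataset.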

If you want to get both the average accuracy and a confusion matrix, you can return these scores through a dictionary.

Example code

from sklearn.metrics import accuracy_score, confusion_matrix

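def confusion_matrix_scorer(clf, X, y):
    # X and y are the held-out fold passed in by cross_validate
    y_pred = clf.predict(X)
    cm = confusion_matrix(y, y_pred)
    acc = accuracy_score(y, y_pred)
    # For binary labels, rows are true classes and columns are predictions
    return {
        'acc': acc,
        'tn': cm[0, 0],
        'fp': cm[0, 1],
        'fn': cm[1, 0],
        'tp': cm[1, 1]
    }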
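As a rough usage sketch, the scorer can be passed directly to cross_validate, and the per-fold values can then be averaged or summed. The toy dataset, the SVC estimator, and the explicit labels=[0, 1] argument (which assumes the classes are encoded as 0 and 1) are my assumptions here; returning a dictionary from a callable scorer needs a reasonably recent scikit-learn (0.24 or later, as far as I know).

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import cross_validate
from sklearn.svm import SVC

def confusion_matrix_scorer(clf, X, y):
    y_pred = clf.predict(X)
    # labels=[0, 1] keeps the matrix 2x2 even if a fold lacks one class
    cm = confusion_matrix(y, y_pred, labels=[0, 1])
    return {
        'acc': accuracy_score(y, y_pred),
        'tn': cm[0, 0], 'fp': cm[0, 1],
        'fn': cm[1, 0], 'tp': cm[1, 1],
    }

# Hypothetical toy data standing in for the real problem.
X, y = make_classification(n_samples=200, random_state=0)

results = cross_validate(
    SVC(kernel="linear"), X, y, cv=5, scoring=confusion_matrix_scorer
)

# Each dictionary key comes back as 'test_<key>' with one value per fold.
mean_acc = results['test_acc'].mean()
total_cm = [
    [results['test_tn'].sum(), results['test_fp'].sum()],
    [results['test_fn'].sum(), results['test_tp'].sum()],
]

print(mean_acc)
print(total_cm)
```

Summing the counts gives one confusion matrix over all held-out samples, since every sample lands in exactly one test fold.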

Upvotes: 3
