Tartaglia

Custom Scoring Function in sklearn Cross Validate

I would like to use a custom scoring function in cross_validate that computes precision against a specific y_test, which is different from the actual target y_test.

I have tried a few approaches with make_scorer, but I don't know how to actually pass my alternative y_test:

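from sklearn.metrics import make_scorer, precision_score
from sklearn.model_selection import cross_validate

scoring = {'prec1': 'precision',
           'custom_prec1': make_scorer(precision_score)}

scores = cross_validate(pipeline, X, y, cv=5, scoring=scoring)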

Can anyone suggest an approach?

Answers (1)

andrewchauzov

I found a way to do this, although the code may not be optimal.

Let's start:

import numpy as np
import pandas as pd

from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer

xTrain = np.random.rand(100, 100)
yTrain = np.random.randint(1, 4, (100, 1))

yTrainCV = np.random.randint(1, 4, (100, 1))

model = LogisticRegression()

yTrainCV is the alternative target that the custom scorer uses in place of yTrain.

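def customLoss(y_true, y_pred):
    # y_true is this fold's slice of the target DataFrame; its index holds
    # the row positions of the fold within the full training set
    indices = y_true.index.values
    # Count predictions that disagree with the alternative target yTrainCV
    mismatches = np.asarray(y_pred).ravel() != yTrainCV[indices].ravel()
    return int(np.sum(mismatches))

# customLoss counts errors, so lower is better: greater_is_better=False
scorer = {'main': 'accuracy',
          'custom': make_scorer(customLoss, greater_is_better=False)}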

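A few tricks here:

  • customLoss must accept two arguments in the order the scorer built by make_scorer passes them: first the true target values for the fold, then the model's predictions.
  • greater_is_better sets the sign: with False, make_scorer negates the function's result, so GridSearchCV (which always maximizes the score) effectively minimizes the loss; with True the result is used as-is.
  • The indices come from the CV split made inside GridSearchCV: each fold's slice of the target keeps the index labels of the rows it was taken from.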
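The three tricks above can be checked in isolation. The sketch below is only illustrative (the toy data, the DummyClassifier and the mismatches function are not part of the original answer): it takes one fold of a KFold split by position, roughly as GridSearchCV does internally, and shows that the scorer receives the target slice first, with its original index labels, and that greater_is_better=False flips the sign.

```python
import numpy as np
import pandas as pd
from sklearn.dummy import DummyClassifier
from sklearn.metrics import make_scorer
from sklearn.model_selection import KFold

# Toy data: 10 rows; the target is a DataFrame so each row keeps its index label
X = np.arange(20).reshape(10, 2)
y = pd.DataFrame({"label": [0, 1] * 5})

calls = []  # records what the scorer hands to the score function

def mismatches(y_true, y_pred):
    # The scorer calls score_func(y_true, y_pred): true values come first
    calls.append((type(y_true).__name__, list(y_true.index)))
    return int(np.sum(np.asarray(y_true).ravel() != np.asarray(y_pred).ravel()))

loss_scorer = make_scorer(mismatches, greater_is_better=False)

# Fourth fold of an unshuffled 5-fold split: test rows are positions 6 and 7
train_idx, test_idx = list(KFold(n_splits=5).split(X))[3]
y_test = y.iloc[test_idx]  # positional slice keeps the original index labels

# Always predicts 0, so exactly one of the two test rows is a mismatch
model = DummyClassifier(strategy="constant", constant=0)
model.fit(X[train_idx], y.iloc[train_idx]["label"])

score = loss_scorer(model, X[test_idx], y_test)
print(score, calls)
```

Here score should be -1 (one mismatch, negated by make_scorer), and the recorded call shows a DataFrame whose index is [6, 7] rather than [0, 1].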

Finally, run the grid search:

grid = GridSearchCV(model,
                    scoring=scorer,
                    cv=5,
                    param_grid={'C': [1e0, 1e1, 1e2, 1e3],
                                'class_weight': ['balanced', None]},
                    refit='custom')
    
grid.fit(xTrain, pd.DataFrame(yTrain))
print(grid.score(xTrain, pd.DataFrame(yTrain)))
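  • Don't forget the refit parameter in GridSearchCV: with several scorers it must name the one used to pick the best parameters and refit the final model (here 'custom').
  • The target array is passed as a DataFrame so that the custom loss function can read each fold's row indices from its index.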
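The question asks about cross_validate rather than GridSearchCV, and the same index trick should carry over, since cross_validate also slices a pandas target by position. A minimal sketch, assuming binary 0/1 labels and a default RangeIndex on the target (so index labels equal row positions); y_alt stands in for the asker's alternative y_test and, like precision_vs_alt, is a hypothetical name:

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import make_scorer, precision_score
from sklearn.model_selection import cross_validate

rng = np.random.default_rng(0)
X = rng.random((100, 5))
y = pd.Series(rng.integers(0, 2, 100))  # target the model is trained on
y_alt = rng.integers(0, 2, 100)         # hypothetical alternative "y_test"

def precision_vs_alt(y_true, y_pred):
    # y_true is this fold's slice of the Series; with a RangeIndex its
    # index labels are the row positions in the full dataset
    rows = y_true.index.to_numpy()
    # Precision of the predictions measured against the alternative target
    return precision_score(y_alt[rows], y_pred, zero_division=0)

scoring = {"prec1": "precision",
           "custom_prec1": make_scorer(precision_vs_alt)}

scores = cross_validate(LogisticRegression(), X, y, cv=5, scoring=scoring)
print(scores["test_prec1"], scores["test_custom_prec1"])
```

If the target has a non-default index, resetting it first (for example with y.reset_index(drop=True)) keeps the lookup into y_alt aligned with the rows.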
