Optimal threshold for an imbalanced binary classification problem

I have trouble optimizing the threshold for binary classification. I am using 3 models: Logistic Regression, CatBoost and scikit-learn's RandomForestClassifier.

For each model I am doing the following steps:

1) fit model

2) get 0.0 recall for the first class (which makes up 5% of the dataset) and 1.0 recall for the zero class. (This can't be fixed with grid search or the class_weight='balanced' parameter.) >:(

3) Find the optimal threshold

import numpy as np
from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(y_train, model.predict_proba(X_train)[:, 1])
optimal_threshold = thresholds[np.argmax(tpr - fpr)]

4) Enjoy a recall of ~0.70 for both classes.

5) Predict probabilities for the test dataset and use the optimal_threshold calculated above to get the classes.
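A minimal end-to-end sketch of the steps above (the make_classification data, the RandomForestClassifier choice and the variable names are placeholders, not my real setup):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, recall_score
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced dataset (~5% positives), stand-in for the real data
X, y = make_classification(n_samples=5000, weights=[0.95, 0.05], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

# 1) fit the model
model = RandomForestClassifier(class_weight='balanced', random_state=0)
model.fit(X_train, y_train)

# 3) find the threshold that maximizes TPR - FPR (Youden's J) on the train set
fpr, tpr, thresholds = roc_curve(y_train, model.predict_proba(X_train)[:, 1])
optimal_threshold = thresholds[np.argmax(tpr - fpr)]

# 5) apply that threshold to the test probabilities
y_pred = (model.predict_proba(X_test)[:, 1] >= optimal_threshold).astype(int)
print(recall_score(y_test, y_pred), recall_score(y_test, y_pred, pos_label=0))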

Here comes the question: when I run the code again and again without fixing random_state, the optimal threshold varies and shifts quite dramatically. This leads to dramatic changes in the accuracy metrics on the test sample.

Do I need to calculate some average threshold and use it as a hard-coded constant? Or maybe I have to fix random_state everywhere? Or maybe the method of finding optimal_threshold isn't correct?
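By "average threshold" I mean something roughly like the following (continuing from the snippet above; the StratifiedKFold split and fold-wise averaging are just an illustration of the idea, not code I actually use):

from sklearn.model_selection import StratifiedKFold

# Compute one Youden's-J threshold per validation fold, then average them
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
fold_thresholds = []
for train_idx, val_idx in skf.split(X_train, y_train):
    model.fit(X_train[train_idx], y_train[train_idx])
    proba = model.predict_proba(X_train[val_idx])[:, 1]
    fpr, tpr, thresholds = roc_curve(y_train[val_idx], proba)
    fold_thresholds.append(thresholds[np.argmax(tpr - fpr)])
average_threshold = np.mean(fold_thresholds)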

Upvotes: 1

Views: 1025

Answers (1)

B200011011

Reputation: 4258

If you do not set random_state to a fixed value, results will be different in every run. To get reproducible results, set random_state to a fixed value everywhere it is required, or use a fixed NumPy random seed via numpy.random.seed.

https://scikit-learn.org/stable/faq.html#how-do-i-set-a-random-state-for-an-entire-execution

The scikit-learn FAQ mentions that it is better to pass random_state where required instead of relying on the global random state.

Global Random State Example:

import numpy as np
np.random.seed(42)

Some examples of setting random_state locally:

X_train, X_test, y_train, y_test = train_test_split(sample.data, sample.target, test_size=0.3, random_state=0)

skf = StratifiedKFold(n_splits=10, random_state=0, shuffle=True)

classifierAlgorithm = LGBMClassifier(objective='binary', random_state=0)
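Applied to the threshold search from your question, fixing random_state at every step (split, model, CV) should make the chosen threshold identical across runs. A rough sketch with placeholder data (make_classification and LogisticRegression here are only stand-ins for your pipeline):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve
from sklearn.model_selection import train_test_split

# Placeholder data; with every random_state below fixed, the threshold
# printed here is the same on every run of the script.
X, y = make_classification(n_samples=5000, weights=[0.95, 0.05], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

model = LogisticRegression(class_weight='balanced', max_iter=1000, random_state=0)
model.fit(X_train, y_train)

fpr, tpr, thresholds = roc_curve(y_train, model.predict_proba(X_train)[:, 1])
print(thresholds[np.argmax(tpr - fpr)])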

Upvotes: 1
