Shihab Shahriar Khan

Reputation: 5485

Failure to reproduce scikit-learn and numpy dependent code when multiprocessing is used

The code below is fully reproducible when n_jobs=1 is passed to cross_validate, but not when n_jobs=-1 or n_jobs=2.

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate,RepeatedStratifiedKFold

class DecisionTree(DecisionTreeClassifier):
    def fit(self,X,Y):
        weight = np.random.uniform(size=Y.shape)
        return super().fit(X,Y,sample_weight=weight)

def main():
    X,Y = load_iris(return_X_y=True)
    rks = RepeatedStratifiedKFold(n_repeats=2,n_splits=5,random_state=42)
    clf = DecisionTree(random_state=42)
    res = cross_validate(clf,X,Y,cv=rks,n_jobs=2)['test_score']*100
    return res.mean(),res.std()

if __name__=='__main__':
    np.random.seed(42)
    print(main())

Note the np.random.uniform call in the fit method. Without that call, the code is fully reproducible even with n_jobs>1. It is mentioned here that numpy.random.seed is not thread-safe, but I saw no mention of this in sklearn's FAQ, according to which providing random_state everywhere should suffice.
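The dependence on NumPy's global state can be seen even without multiprocessing. The following is a minimal sketch, not part of the original code, and the variable names are made up for illustration. It contrasts draws from the seeded global generator with draws from an explicitly seeded RandomState. With n_jobs>1 the fits run in separate worker processes, which presumably do not share the global state seeded in the main process (this is an assumption about joblib's default backend).

```python
import numpy as np

# Global state: every call advances the one shared generator, so the
# result depends on how many draws happened before in this process.
np.random.seed(42)
first_global = np.random.uniform(size=3)
second_global = np.random.uniform(size=3)

# Explicit generator: the result depends only on the seed given here,
# not on whatever state the global generator happens to be in.
first_local = np.random.RandomState(42).uniform(size=3)
second_local = np.random.RandomState(42).uniform(size=3)

print(np.array_equal(first_global, second_global))
print(np.array_equal(first_local, second_local))
```

The two global draws differ because the shared state moved on between them, while the two explicitly seeded draws are identical.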

Is there any way to use both numpy random calls and multiprocessing in sklearn while maintaining full reproducibility?

EDIT: I think the results reproduce fine if n_jobs>1 is instead set on estimators that accept it, for example when instantiating RandomForestClassifier.

Upvotes: 1

Views: 177

Answers (1)

Sam Mason

Reputation: 16213

It would appear your DecisionTree class should be using the random_state that was passed in. I get consistent results when I do:

from sklearn.utils import check_random_state

class DecisionTree(DecisionTreeClassifier):
    def fit(self, X, Y):
        rng = check_random_state(self.random_state)
        weight = rng.uniform(size=Y.shape)
        return super().fit(X, Y, sample_weight=weight)

and otherwise leave the rest of your code unchanged. Note that with this change you can also remove the call to np.random.seed(42), as the RNG state is explicitly set everywhere it needs to be.
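One way to check this is to run the cross-validation several times with different n_jobs values and compare the scores. The following is a minimal sketch that combines the class above with the setup from the question. The helper name run_cv is hypothetical and only there for the comparison.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import RepeatedStratifiedKFold, cross_validate
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import check_random_state


class DecisionTree(DecisionTreeClassifier):
    def fit(self, X, Y):
        # Draw the weights from the estimator's own random_state
        # instead of NumPy's global generator.
        rng = check_random_state(self.random_state)
        weight = rng.uniform(size=Y.shape)
        return super().fit(X, Y, sample_weight=weight)


def run_cv(n_jobs):
    # Hypothetical helper: same setup as in the question, with n_jobs
    # as a parameter and no call to np.random.seed anywhere.
    X, Y = load_iris(return_X_y=True)
    rks = RepeatedStratifiedKFold(n_repeats=2, n_splits=5, random_state=42)
    clf = DecisionTree(random_state=42)
    return cross_validate(clf, X, Y, cv=rks, n_jobs=n_jobs)["test_score"]


if __name__ == "__main__":
    serial = run_cv(n_jobs=1)
    parallel_a = run_cv(n_jobs=2)
    parallel_b = run_cv(n_jobs=2)
    # Both comparisons should report identical per-fold scores.
    print(np.array_equal(parallel_a, parallel_b))
    print(np.array_equal(serial, parallel_a))
```

One thing to be aware of: because random_state is the integer 42, check_random_state builds a fresh RandomState(42) on every call to fit, so every fold with the same training-set size gets the same weight vector. That is what makes the result independent of process boundaries, but it may not be what you want if the weights are meant to differ between folds.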

Upvotes: 2
