Pablo Moreira Garcia

Reputation: 117

Kernel dies when using shap.KernelExplainer()

I'm trying to use SHAP's KernelExplainer with a model imported from WEKA or built with python-weka-wrapper3. I have a class that wraps these models so they can be used with Python libraries:

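import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from weka.core.dataset import Instance

class weka_classifier(BaseEstimator, ClassifierMixin):
    
    # `index` was referenced but never declared; it is now an optional argument
    def __init__(self, classifier = None, dataset = None, index = None):
        if classifier is not None:
            self.classifier = classifier
        if dataset is not None:
            self.dataset = dataset
            self.dataset.class_is_last()
        if index is not None:
            self.index = index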
               
    def fit(self, X, y):
        return self.fit2()
    
    def fit2(self):
        return self.classifier.build_classifier(self.dataset)
    
    def predict_instance(self,x):
        # copy before adding the class placeholder so the caller's list is not mutated
        x = list(x) + [0.0]
        inst = Instance.create_instance(x,classname='weka.core.DenseInstance', weight=1.0)
        inst.dataset = self.dataset
        
        return self.classifier.classify_instance(inst)
    
    def predict_proba_instance(self,x):
        # copy before adding the class placeholder so the caller's list is not mutated
        x = list(x) + [0.0]
        inst = Instance.create_instance(x,classname='weka.core.DenseInstance', weight=1.0)
        inst.dataset = self.dataset
        
        return self.classifier.distribution_for_instance(inst)
    
    def predict_proba(self,X):
        prediction = []

        for i in range(X.shape[0]):
            instance = []
            for j in range(X.shape[1]):
                instance.append(X[i][j])
            instance.append(0.0)
            instance = Instance.create_instance(instance,classname='weka.core.DenseInstance', weight=1.0)
            instance.dataset=self.dataset
            prediction.append(self.classifier.distribution_for_instance(instance))

        return np.asarray(prediction)    
    
    def predict(self,X):
        prediction = []
        
        for i in range(X.shape[0]):
            instance = []
            for j in range(X.shape[1]):
                instance.append(X[i][j])
            instance.append(0.0)
            instance = Instance.create_instance(instance,classname='weka.core.DenseInstance', weight=1.0)
            instance.dataset=self.dataset
            prediction.append(self.classifier.classify_instance(instance))
            
        return np.asarray(prediction)
    

    def set_data(self,dataset):
        self.dataset = dataset
        self.dataset.class_is_last()

This works for me with small datasets, such as this one:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 260 entries, 0 to 259
Data columns (total 6 columns):
 #   Column                                     Non-Null Count  Dtype  
---  ------                                     --------------  -----  
 0   BMI                                        260 non-null    float64
 1   ROM-PADF-KE_D                              260 non-null    int64  
 2   Asym-ROM-PHIR(≥8)_discr                    260 non-null    int64  
 3   Asym_SLCMJLanding-pVGRF(10percent)_discr   260 non-null    int64  
 4   Asym_TJ_Valgus_FPPA(10percent)_discr       260 non-null    int64  
 5   DVJ_Valgus_KneeMedialDisplacement_D_discr  260 non-null    int64  
dtypes: float64(1), int64(5)

When I run this, it works:

explainer_num = shap.KernelExplainer(sci_Model_1.predict, X_num)
shap_values_num = explainer_num.shap_values(X_num)

It takes about 19 minutes:

WARNING:shap:Using 260 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.

100%
260/260 [18:54<00:00, 4.77s/it]

So that's fine. But when I try a bigger dataset (not huge, just more columns), such as this one:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 260 entries, 0 to 259
Data columns (total 62 columns):
 #   Column                                              Non-Null Count  Dtype  
---  ------                                              --------------  -----  
 0   DVJ_Valgus_KneeMedialDisplacement_D_discr           260 non-null    int64  
 1   BMI                                                 260 non-null    float64
 2   AgeGroup                                            260 non-null    object 
 3   ROM-PADF-KE_D                                       260 non-null    int64  
 4   DVJ_Valgus_FPPA_D_discr                             260 non-null    int64  
 5   TrainFrequency                                      260 non-null    int64  
 6   DVJ_Valgus_FPPA_ND_discr                            260 non-null    int64  
 7   Asym_SLCMJLanding-pVGRF(10percent)_discr            260 non-null    object 
 8   Asym-ROM-PHIR(≥8)_discr                             260 non-null    object 
 9   Asym_TJ_Valgus_FPPA(10percent)_discr                260 non-null    object 
 10  TJ_Valgus_FPPA_ND_discr                             260 non-null    int64  
 11  Asym-ROM-PHF-KE(≥8)_discr                           260 non-null    object 
 12  TJ_Valgus_FPPA_D_discr                              260 non-null    int64  
 13  Asym_SLCMJ-Height(10percent)_discr                  260 non-null    object 
 14  Asym_YBTpl(10percent)_discr                         260 non-null    object 
 15  Position                                            260 non-null    object 
 16  Asym-ROM-PADF-KE(≥8º)_discr                         260 non-null    object 
 17  DVJ_Valgus_KneeMedialDisplacement_ND_discr          260 non-null    int64  
 18  DVJ_Valgus_Knee-to-ankle-ratio_discr                260 non-null    object 
 19  Asym-ROM-PKF(≥8)_discr                              260 non-null    object 
 20  Asym-ROM-PHABD(≥8)_discr                            260 non-null    object 
 21  Asym-ROM-PHF-KF(≥8)_discr                           260 non-null    object 
 22  Asym-ROM-PHER(≥8)_discr                             260 non-null    object 
 23  AsymYBTanterior10percentdiscr                       260 non-null    object 
 24  Asym-ROM-PHABD-HF(≥8)_discr                         260 non-null    object 
 25  Asym-ROM-PHE(≥8)_discr                              260 non-null    object 
 26  Asym(>4cm)-DVJ_Valgus_Knee;edialDisplacement_discr  260 non-null    object 
 27  Asym_SLCMJTakeOff-pVGRF(10percent)_discr            260 non-null    object 
 28  Asym-ROM-PHADD(≥8)_discr                            260 non-null    object 
 29  Asym-YBTcomposite(10percent)_discr                  260 non-null    object 
 30  Asym_SingleHop(10percent)_discr                     260 non-null    object 
 31  Asym_YBTpm(10percent)_discr                         260 non-null    object 
 32  Asym_DVJ_Valgus_FPPA(10percent)_discr               260 non-null    object 
 33  Asym_SLCMJ-pLFT(10percent)_discr                    260 non-null    object 
 34  DominantLeg                                         260 non-null    object 
 35  Asym-ROM-PADF-KF(≥8)_discr                          260 non-null    object 
 36  ROM-PHER_ND                                         260 non-null    int64  
 37  CPRDmentalskills                                    260 non-null    object 
 38  POMStension                                         260 non-null    int64  
 39  STAI-R                                              260 non-null    float64
 40  ROM-PHER_D                                          260 non-null    int64  
 41  ROM-PHIR_D                                          260 non-null    int64  
 42  ROM-PADF-KF_ND                                      260 non-null    int64  
 43  ROM-PADF-KF_D                                       260 non-null    int64  
 44  Age_at_PHV                                          260 non-null    float64
 45  ROM-PHIR_ND                                         260 non-null    int64  
 46  CPRDtcohesion                                       260 non-null    object 
 47  Eperience                                           260 non-null    float64
 48  ROM-PHABD-HF_D                                      260 non-null    int64  
 49  MaturityOffset                                      260 non-null    float64
 50  Weight                                              260 non-null    float64
 51  ROM-PHADD_ND                                        260 non-null    int64  
 52  Height                                              260 non-null    float64
 53  ROM-PHADD_D                                         260 non-null    int64  
 54  Age                                                 260 non-null    float64
 55  POMSdepressio                                       260 non-null    int64  
 56  ROM-PADF-KE_ND                                      260 non-null    int64  
 57  POMSanger                                           260 non-null    int64  
 58  YBTanterior_Dnorm                                   260 non-null    float64
 59  YBTanterior_NDnorm                                  260 non-null    float64
 60  POMSvigour                                          260 non-null    int64  
 61  Soft-Tissue_injury_≥4days                           260 non-null    object 
dtypes: float64(10), int64(22), object(30)
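One visible difference from the small frame is that 30 of these columns have dtype object, while the wrapper class above passes every value straight to Instance.create_instance. A possible thing to check, offered as a sketch rather than a diagnosis: whether the matrix handed to the explainer is fully numeric. The toy frame and the helper name encode_non_numeric_columns below are hypothetical, and whether the resulting integer codes line up with the label order of WEKA's nominal attributes is an assumption this sketch does not verify.

```python
import pandas as pd


def encode_non_numeric_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy in which every non-numeric column holds integer category codes."""
    out = df.copy()
    for col in out.select_dtypes(exclude="number").columns:
        # Categories are sorted, so codes follow the sorted label order.
        out[col] = out[col].astype("category").cat.codes
    return out


# Hypothetical toy data standing in for a few columns of the real frame.
toy = pd.DataFrame(
    {
        "BMI": [21.5, 24.0, 19.8],
        "DominantLeg": ["R", "L", "R"],
        "Position": ["GK", "DF", "GK"],
    }
)

encoded = encode_non_numeric_columns(toy)
X = encoded.to_numpy(dtype=float)  # fails loudly if anything non-numeric is left

print(encoded.dtypes)
print(X)
```

If the conversion to a float array raises on the real data, the strings would also have reached the wrapper's predict method unchanged.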

I run the kernel explainer as before, but it always stops with the message "The kernel appears to have died":

  0%|          | 0/260 [00:00<?, ?it/s]

I tried reducing the number of instances, but it only works some of the time, so I don't think that's a solution:

  0%|          | 0/5 [00:00<?, ?it/s]
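To get a feel for how the workload scales with columns and background rows, here is a rough back-of-the-envelope sketch. It assumes SHAP's documented default of nsamples = 2 * M + 2048 for M features, capped at 2**M - 2 when M is small, and that each sampled coalition is evaluated against every background row. The function names are hypothetical, the 62 is simply the column count from the listing above (it may include the target), and 8 bytes per value assumes float64 data.

```python
def auto_nsamples(n_features: int) -> int:
    """Approximate the default number of coalitions sampled per explained row."""
    nsamples = 2 * n_features + 2048
    if n_features <= 30:
        # With few features every coalition can be enumerated, so fewer are needed.
        nsamples = min(nsamples, 2 ** n_features - 2)
    return nsamples


def rows_per_explained_instance(n_features: int, n_background: int) -> int:
    """Synthetic rows sent to the model's predict function for ONE explained row."""
    return auto_nsamples(n_features) * n_background


def synthetic_matrix_megabytes(n_features: int, n_background: int,
                               bytes_per_value: int = 8) -> float:
    """Approximate size of the synthetic data matrix for one explained row."""
    rows = rows_per_explained_instance(n_features, n_background)
    return rows * n_features * bytes_per_value / 1e6


# (features, background rows): small frame, large frame, large frame with 10 samples
for n_features, n_background in [(6, 260), (62, 260), (62, 10)]:
    rows = rows_per_explained_instance(n_features, n_background)
    mb = synthetic_matrix_megabytes(n_features, n_background)
    print(f"{n_features} features, {n_background} background rows: "
          f"{rows} model rows, about {mb:.1f} MB per explained instance")
```

Under these assumptions each of those rows is also turned into a separate WEKA Instance by the wrapper's predict loop, so the row count is a proxy for the number of Java calls as well as for memory.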

So, what should I do? Next I'll try reducing the number of columns and combining the results later, or reducing the number of instances with shap.sample(X, ...), but I don't know whether that will work. If anyone knows a solution, I'd appreciate it.

Thank you guys :)

EDIT:

As @fracpete and @Sergey Bushmanov suggested, I tried increasing the JVM's max_heap_size with:

jvm.start(system_cp=True, packages=True, max_heap_size="10g")

and sampling 10 data points with:

explainer_3 = shap.KernelExplainer(sci_Model_3.predict, shap.sample(X_test,10))
shap_values_3 = explainer_3.shap_values(shap.sample(X_test,10))

But that didn't solve the problem; the kernel keeps dying. Is there any other solution?

Thanks guys :)

Upvotes: 0

Views: 2805

Answers (0)
