I'm trying to use SHAP's KernelExplainer with a model imported from WEKA or built with Python-Weka-Wrapper3. I have a class that makes these models usable by Python libraries:
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from weka.core.dataset import Instance

class weka_classifier(BaseEstimator, ClassifierMixin):
    def __init__(self, classifier=None, dataset=None, index=None):
        if classifier is not None:
            self.classifier = classifier
        if dataset is not None:
            self.dataset = dataset
            self.dataset.class_is_last()
        if index is not None:
            self.index = index

    def fit(self, X, y):
        # X and y are ignored: the WEKA classifier is built on self.dataset
        self.fit2()
        return self  # scikit-learn convention: fit returns the estimator

    def fit2(self):
        return self.classifier.build_classifier(self.dataset)

    def predict_instance(self, x):
        x.append(0.0)  # placeholder for the (unknown) class attribute
        inst = Instance.create_instance(x, classname='weka.core.DenseInstance', weight=1.0)
        inst.dataset = self.dataset
        return self.classifier.classify_instance(inst)

    def predict_proba_instance(self, x):
        x.append(0.0)  # placeholder for the (unknown) class attribute
        inst = Instance.create_instance(x, classname='weka.core.DenseInstance', weight=1.0)
        inst.dataset = self.dataset
        return self.classifier.distribution_for_instance(inst)

    def predict_proba(self, X):
        prediction = []
        for i in range(X.shape[0]):
            values = list(X[i]) + [0.0]  # row values plus class placeholder
            inst = Instance.create_instance(values, classname='weka.core.DenseInstance', weight=1.0)
            inst.dataset = self.dataset
            prediction.append(self.classifier.distribution_for_instance(inst))
        return np.asarray(prediction)

    def predict(self, X):
        prediction = []
        for i in range(X.shape[0]):
            values = list(X[i]) + [0.0]  # row values plus class placeholder
            inst = Instance.create_instance(values, classname='weka.core.DenseInstance', weight=1.0)
            inst.dataset = self.dataset
            prediction.append(self.classifier.classify_instance(inst))
        return np.asarray(prediction)

    def set_data(self, dataset):
        self.dataset = dataset
        self.dataset.class_is_last()
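For context, this is roughly how I use the wrapper (just a sketch: the file name and the RandomForest classifier are made-up examples, my real model comes from WEKA):

import weka.core.jvm as jvm
from weka.core.converters import load_any_file
from weka.classifiers import Classifier

jvm.start(packages=True)

# Hypothetical setup just to show how the wrapper is used;
# my actual dataset/classifier come from my WEKA project.
dataset = load_any_file("data.arff")  # hypothetical file name
dataset.class_is_last()
classifier = Classifier(classname="weka.classifiers.trees.RandomForest")

sci_Model_1 = weka_classifier(classifier=classifier, dataset=dataset)
sci_Model_1.fit(None, None)  # builds the WEKA classifier on `dataset`
probs = sci_Model_1.predict_proba(X_num.values)  # X_num: the numeric DataFrame below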
This works for me with small datasets, something like this one:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 260 entries, 0 to 259
Data columns (total 6 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 BMI 260 non-null float64
1 ROM-PADF-KE_D 260 non-null int64
2 Asym-ROM-PHIR(≥8)_discr 260 non-null int64
3 Asym_SLCMJLanding-pVGRF(10percent)_discr 260 non-null int64
4 Asym_TJ_Valgus_FPPA(10percent)_discr 260 non-null int64
5 DVJ_Valgus_KneeMedialDisplacement_D_discr 260 non-null int64
dtypes: float64(1), int64(5)
When I run the explainer like this, it works:
explainer_num = shap.KernelExplainer(sci_Model_1.predict, X_num)
shap_values_num = explainer_num.shap_values(X_num)
taking about 19 minutes, more or less:
WARNING:shap:Using 260 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
100% 260/260 [18:54<00:00, 4.77s/it]
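The warning suggests summarizing the background data, so I guess something like this should also work (just a sketch following the warning's own suggestion, with K=10):

# Summarize the 260 background rows down to K=10 centers instead of
# passing the whole dataset, as the SHAP warning suggests.
background = shap.kmeans(X_num, 10)  # or: shap.sample(X_num, 10)

explainer_num = shap.KernelExplainer(sci_Model_1.predict, background)
shap_values_num = explainer_num.shap_values(X_num)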
So that's fine, but when I try a bigger dataset (not huge, just a bit larger), something like this:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 260 entries, 0 to 259
Data columns (total 62 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 DVJ_Valgus_KneeMedialDisplacement_D_discr 260 non-null int64
1 BMI 260 non-null float64
2 AgeGroup 260 non-null object
3 ROM-PADF-KE_D 260 non-null int64
4 DVJ_Valgus_FPPA_D_discr 260 non-null int64
5 TrainFrequency 260 non-null int64
6 DVJ_Valgus_FPPA_ND_discr 260 non-null int64
7 Asym_SLCMJLanding-pVGRF(10percent)_discr 260 non-null object
8 Asym-ROM-PHIR(≥8)_discr 260 non-null object
9 Asym_TJ_Valgus_FPPA(10percent)_discr 260 non-null object
10 TJ_Valgus_FPPA_ND_discr 260 non-null int64
11 Asym-ROM-PHF-KE(≥8)_discr 260 non-null object
12 TJ_Valgus_FPPA_D_discr 260 non-null int64
13 Asym_SLCMJ-Height(10percent)_discr 260 non-null object
14 Asym_YBTpl(10percent)_discr 260 non-null object
15 Position 260 non-null object
16 Asym-ROM-PADF-KE(≥8º)_discr 260 non-null object
17 DVJ_Valgus_KneeMedialDisplacement_ND_discr 260 non-null int64
18 DVJ_Valgus_Knee-to-ankle-ratio_discr 260 non-null object
19 Asym-ROM-PKF(≥8)_discr 260 non-null object
20 Asym-ROM-PHABD(≥8)_discr 260 non-null object
21 Asym-ROM-PHF-KF(≥8)_discr 260 non-null object
22 Asym-ROM-PHER(≥8)_discr 260 non-null object
23 AsymYBTanterior10percentdiscr 260 non-null object
24 Asym-ROM-PHABD-HF(≥8)_discr 260 non-null object
25 Asym-ROM-PHE(≥8)_discr 260 non-null object
26 Asym(>4cm)-DVJ_Valgus_KneeMedialDisplacement_discr 260 non-null object
27 Asym_SLCMJTakeOff-pVGRF(10percent)_discr 260 non-null object
28 Asym-ROM-PHADD(≥8)_discr 260 non-null object
29 Asym-YBTcomposite(10percent)_discr 260 non-null object
30 Asym_SingleHop(10percent)_discr 260 non-null object
31 Asym_YBTpm(10percent)_discr 260 non-null object
32 Asym_DVJ_Valgus_FPPA(10percent)_discr 260 non-null object
33 Asym_SLCMJ-pLFT(10percent)_discr 260 non-null object
34 DominantLeg 260 non-null object
35 Asym-ROM-PADF-KF(≥8)_discr 260 non-null object
36 ROM-PHER_ND 260 non-null int64
37 CPRDmentalskills 260 non-null object
38 POMStension 260 non-null int64
39 STAI-R 260 non-null float64
40 ROM-PHER_D 260 non-null int64
41 ROM-PHIR_D 260 non-null int64
42 ROM-PADF-KF_ND 260 non-null int64
43 ROM-PADF-KF_D 260 non-null int64
44 Age_at_PHV 260 non-null float64
45 ROM-PHIR_ND 260 non-null int64
46 CPRDtcohesion 260 non-null object
47 Eperience 260 non-null float64
48 ROM-PHABD-HF_D 260 non-null int64
49 MaturityOffset 260 non-null float64
50 Weight 260 non-null float64
51 ROM-PHADD_ND 260 non-null int64
52 Height 260 non-null float64
53 ROM-PHADD_D 260 non-null int64
54 Age 260 non-null float64
55 POMSdepressio 260 non-null int64
56 ROM-PADF-KE_ND 260 non-null int64
57 POMSanger 260 non-null int64
58 YBTanterior_Dnorm 260 non-null float64
59 YBTanterior_NDnorm 260 non-null float64
60 POMSvigour 260 non-null int64
61 Soft-Tissue_injury_≥4days 260 non-null object
dtypes: float64(10), int64(22), object(30)
I run the KernelExplainer like before, but it always stops with the message "The kernel appears to have died":
0%| | 0/260 [00:00<?, ?it/s]
I tried reducing the number of instances, but it works only sometimes, so I don't think that's a solution:
0%| | 0/5 [00:00<?, ?it/s]
So, what should I do? For now I'll try reducing the number of columns and putting the results together afterwards (see the sketch below), or reducing the number of instances with shap.sample(X, ...), but I don't know if either will work. If someone knows a solution, I'd appreciate it.
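For the column-reduction idea, I mean something like this (just a sketch: the WEKA model would need to be rebuilt on the reduced columns, so sci_Model_red is a hypothetical wrapper trained on them, not my current model):

# Sketch of the column-reduction idea: keep only the numeric columns
# and explain a model built on just those.
X_red = X.select_dtypes(include=["number"])  # X: the full 62-column DataFrame

explainer_red = shap.KernelExplainer(sci_Model_red.predict, shap.sample(X_red, 10))
shap_values_red = explainer_red.shap_values(X_red)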
Thank you guys :)
EDIT:
As @fracpete and @Sergey Bushmanov suggested, I tried increasing the JVM's max_heap_size with:
jvm.start(system_cp=True, packages=True, max_heap_size="10g")
and sampling 10 data points with:
explainer_3 = shap.KernelExplainer(sci_Model_3.predict, shap.sample(X_test,10))
shap_values_3 = explainer_3.shap_values(shap.sample(X_test,10))
But it didn't solve this problem; the kernel keeps dying. Any other solutions?
Thanks guys :)