Joe

Reputation: 395

Extract per-class feature importance from a SHAP summary plot for a multi-class problem

How can I generate a table of feature importances for a specific class using SHAP?

[Image: SHAP summary plot for the multi-class model]

From the plot above, how can I extract the feature importance for just class 6?

I saw here that for a binary classification problem you can extract the per-class SHAP values via:

# shap values for survival
sv_survive = sv[:,y,:]
# shap values for dying
sv_die = sv[:,~y,:]

How can I adapt this code to work for a multiclass problem?

I need to extract the SHAP values underlying the feature importance for class 6.

Here is the beginning of my code:

from sklearn.datasets import make_classification
import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import pickle
import joblib
import warnings
import shap
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV

f, (ax1,ax2) = plt.subplots(nrows=1, ncols=2,figsize=(20,8))
# Generate noisy Data
X_train,y_train = make_classification(n_samples=1000, 
                          n_features=50, 
                          n_informative=9, 
                          n_redundant=0, 
                          n_repeated=0, 
                          n_classes=10, 
                          n_clusters_per_class=1,
                          class_sep=9,
                          flip_y=0.2,
                          #weights=[0.5,0.5], 
                          random_state=17)

X_test,y_test = make_classification(n_samples=500, 
                          n_features=50, 
                          n_informative=9, 
                          n_redundant=0, 
                          n_repeated=0, 
                          n_classes=10, 
                          n_clusters_per_class=1,
                          class_sep=9,
                          flip_y=0.2,
                          #weights=[0.5,0.5], 
                          random_state=17)

model = RandomForestClassifier()

parameter_space = {
    'n_estimators': [10,50,100],
    'criterion': ['gini', 'entropy'],
    'max_depth': np.linspace(10, 50, 11, dtype=int),  # max_depth must be an int
}

clf = GridSearchCV(model, parameter_space, cv = 5, scoring = "accuracy", verbose = True) # model
my_model = clf.fit(X_train,y_train)
print(f'Best Parameters: {clf.best_params_}')

# save the model to disk
filename = 'Test-RF.sav'
pickle.dump(clf, open(filename, 'wb'))

explainer = shap.Explainer(clf.best_estimator_)
shap_values_tr1 = explainer.shap_values(X_train)

Upvotes: 3

Views: 1945

Answers (1)

Sergey Bushmanov

Reputation: 25189

Let's try a minimal reproducible example:

from sklearn.datasets import make_classification
from shap import Explainer, waterfall_plot, Explanation
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import numpy as np

# Generate noisy Data
X, y = make_classification(n_samples=1000, 
                          n_features=50, 
                          n_informative=9, 
                          n_redundant=0, 
                          n_repeated=0, 
                          n_classes=10, 
                          n_clusters_per_class=1,
                          class_sep=9,
                          flip_y=0.2,
                          random_state=17)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

model = RandomForestClassifier()
model.fit(X_train, y_train)

explainer = Explainer(model)
sv = explainer.shap_values(X_test)

I claim you can reach your goal with:

cls = 9   # class to explain
sv_cls = sv[cls]
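
From there, the feature-importance table you asked about is just an aggregation of sv_cls, e.g. the mean absolute SHAP value per feature. A minimal sketch, assuming generic feature names (substitute your own column names):

import pandas as pd

# mean |SHAP| per feature for the chosen class, largest first
importance = pd.DataFrame({
    "feature": [f"feature_{i}" for i in range(X_test.shape[1])],
    "mean_abs_shap": np.abs(sv_cls).mean(axis=0),
}).sort_values("mean_abs_shap", ascending=False)
print(importance.head(10))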

Why?

We should be able to explain a datapoint:

idx = 99  # datapoint to explain
pred = model.predict_proba(X_test[[idx]])[:, cls]
pred

array([0.01])

We can check that we're doing this right visually:

waterfall_plot(Explanation(sv_cls[idx], explainer.expected_value[cls]))

[Image: waterfall plot for class 9, datapoint 99]

and mathematically:

np.allclose(pred, explainer.expected_value[cls] + sv[cls][idx].sum())

True
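
One version caveat: in recent shap releases, shap_values for a multiclass tree model may come back as a single 3-D array of shape (n_samples, n_features, n_classes) rather than a list of per-class 2-D arrays. If that's what you get, slice the last axis instead:

# if sv is a 3-D array (n_samples, n_features, n_classes)
sv_cls = sv[:, :, cls]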

Upvotes: 3
