dingaro

Reputation: 2342

How to make a SHAP summary_plot for only selected features from a list in Python?

I am trying to make a SHAP summary plot in Python for only a selected subset of the features used by my ML model.

Normally, a SHAP summary plot for all features is made like this:

import shap

model = clf
explainer = shap.Explainer(model)
shap_values = explainer(X_test)

shap.summary_plot(shap_values, X_test)

But how can I make the plot for only some of the features in X_test, for example only the features in a list such as my_list = ['val1', 'val2', 'val3']?

Upvotes: 1

Views: 2162

Answers (1)

Chan Bulgin

Reputation: 31

You can do this by keeping only the selected features in both inputs you pass to summary_plot (the SHAP values and the feature matrix), so that their columns stay aligned:

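import pandas as pd
import shap

model = clf  # your fitted model
explainer = shap.Explainer(model)
shap_values = explainer(X_test)  # a shap.Explanation object

features_list = ['col1', 'col13', 'col25']

# Put the raw SHAP values (.values, 2-D for a single-output model) in a
# DataFrame labelled with X_test's column names, so columns can be picked by name
shap_values_fl = pd.DataFrame(shap_values.values, columns=X_test.columns)
shap_values_fl = shap_values_fl[features_list]
X_test_fl = X_test[features_list]

# Pass the filtered SHAP values together with the matching filtered features
shap.summary_plot(shap_values_fl.values, X_test_fl)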
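If you would rather not build an intermediate DataFrame, the same selection can be done by column position. Below is a minimal sketch using a hypothetical helper, select_shap_columns (it is not part of SHAP). It assumes a single-output model, so that shap_values.values is a 2-D array whose columns follow the order of X_test.columns. The toy data are made up and only numpy and pandas are needed to run it.

```python
import numpy as np
import pandas as pd


def select_shap_columns(shap_matrix, X, features):
    """Hypothetical helper: keep only `features` in both the SHAP matrix and X.

    shap_matrix: 2-D array (n_samples, n_features), e.g. shap_values.values
        for a single-output model (assumption).
    X: DataFrame whose column order matches shap_matrix's columns.
    features: list of column names to keep, in the desired order.
    """
    shap_matrix = np.asarray(shap_matrix)
    if shap_matrix.shape != X.shape:
        raise ValueError("shap_matrix and X must have the same shape")
    missing = [f for f in features if f not in X.columns]
    if missing:
        raise KeyError(f"features not in X: {missing}")
    # Positional index of each selected column, in the order requested
    idx = [X.columns.get_loc(f) for f in features]
    return shap_matrix[:, idx], X[features]


# Toy stand-ins for X_test and shap_values.values (made-up numbers)
X_demo = pd.DataFrame({"col1": [1.0, 2.0], "col2": [3.0, 4.0], "col3": [5.0, 6.0]})
shap_demo = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

# Keep the SHAP columns for col3 and col1, in that order
sv_sel, X_sel = select_shap_columns(shap_demo, X_demo, ["col3", "col1"])
print(sv_sel)
print(list(X_sel.columns))

# With shap installed, the pair would then be plotted with:
# shap.summary_plot(sv_sel, X_sel)
```

In the real case you would call select_shap_columns(shap_values.values, X_test, features_list) and pass the two returned objects to shap.summary_plot. Selecting by position from the same column list used for X_test keeps the SHAP columns and the feature columns aligned.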

Upvotes: 3
