Reputation: 415
This question refers to "Obtaining summary shap plot for catboost model with tidymodels in R". Judging by the comment below that question, the OP found a solution but has not shared it with the community so far.
I want to analyze my tree ensembles fitted with the tidymodels package using SHAP value plots: plots for single observations (force plots) and plots that summarize the effect of all features in my dataset (summary plots).
DALEXtra provides a function, explain_tidymodels(), to create explainers (and from them SHAP values) for tidymodels models. force_plot() from the fastshap package provides a wrapper around the plotting function of the underlying Python package shap. But I can't figure out how to make force_plot() work with the output of explain_tidymodels().
Question: How can one generate such SHAP plots in R using tidymodels and explain_tidymodels()?
MWE (for SHAP values with explain_tidymodels())
library(MASS)
library(tidyverse)
library(tidymodels)
library(parsnip)
library(treesnip)
library(catboost)
library(fastshap)
library(DALEXtra)
set.seed(1337)
rec <- recipe(crim ~ ., data = Boston)
split <- initial_split(Boston)
train_data <- training(split)
test_data <- testing(split) %>% dplyr::select(-crim) %>% as.matrix()
model_default <-
  parsnip::boost_tree(
    mode = "regression"
  ) %>%
  set_engine(engine = "catboost", loss_function = "RMSE")
# sometimes catboost is not loaded correctly; the following two lines
# prevent fitting errors
# https://github.com/curso-r/treesnip/issues/21 (error is mentioned in the last post)
set_dependency("boost_tree", eng = "catboost", "catboost")
set_dependency("boost_tree", eng = "catboost", "treesnip")
# fit the workflow with the model spec defined above
model_fit_wf <- workflow() %>%
  add_model(model_default) %>%
  add_recipe(rec) %>%
  fit(data = train_data)
# explain_tidymodels() takes the predictors (without the outcome) as data
SHAP_wf <- explain_tidymodels(
  model_fit_wf,
  data = train_data %>% dplyr::select(-crim),
  y = train_data$crim
)
Upvotes: 9
Views: 2617
Reputation: 76
Perhaps this will help. At the very least, it is a step in the right direction.
First, make sure fastshap and reticulate are installed (e.g., install.packages("...")). Next, set up a Python virtual environment and install shap in it (pip install ...). Also install matplotlib 3.2.2 for the dependence plots: an older version of matplotlib is necessary (see the related GitHub issues).
RStudio has good documentation on virtual environment setup. That said, how much troubleshooting the setup requires depends on the IDE you use. (Sadly, some workplaces restrict the use of open-source RStudio due to licensing.)
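The environment setup described above can be scripted from R with reticulate. This is a sketch, not a tested recipe: it assumes Python and pip are available on the system, the environment name "r-shap" is an arbitrary choice, and setup_shap_env() is a hypothetical helper name.

```r
# Hypothetical helper: create a Python virtual environment for the shap plots.
# Assumes Python and pip are installed; "r-shap" is an arbitrary env name.
setup_shap_env <- function(envname = "r-shap") {
  if (!reticulate::virtualenv_exists(envname)) {
    reticulate::virtualenv_create(envname)
  }
  # Pin matplotlib to 3.2.2, as suggested above for the dependence plots
  reticulate::virtualenv_install(envname, packages = c("shap", "matplotlib==3.2.2"))
  # Point reticulate at this environment; must happen before Python starts
  reticulate::use_virtualenv(envname, required = TRUE)
  invisible(envname)
}
```

Call setup_shap_env() in a fresh R session before the first import("shap"): once reticulate has initialized a Python interpreter, it cannot switch to another environment in the same session.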
The fastshap documentation is also helpful on this front.
Here's a workflow for lightgbm (from the treesnip docs, lightly modified).
library(tidymodels)
library(treesnip)
data("diamonds", package = "ggplot2")
diamonds <- diamonds %>% sample_n(1000)
# vfold resamples
diamonds_splits <- vfold_cv(diamonds, v = 5)
model_spec <- boost_tree(mtry = 5, trees = 500) %>% set_mode("regression")
# model specs
lightgbm_model <- model_spec %>%
  set_engine("lightgbm", nthread = 6)
# workflows
lightgbm_wf <- workflow() %>%
  add_model(lightgbm_model)
rec_ordered <- recipe(price ~ ., data = diamonds)
lightgbm_fit_ordered <- fit_resamples(
  add_recipe(lightgbm_wf, rec_ordered),
  resamples = diamonds_splits
)
Before predicting, we fit the workflow on the full data set.
fit_workflow <- lightgbm_wf %>%
add_recipe(rec_ordered) %>%
fit(data = diamonds)
Now we have a fitted workflow and can predict. To use fastshap::explain(), we need a prediction wrapper that returns a numeric vector of predictions. (Depending on the engine, the default may work out of the box; see the fastshap docs.)
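predict_function_gbm <- function(model, newdata) {
  # predict() on a workflow returns a tibble with a .pred column;
  # fastshap::explain() needs a plain numeric vector, so extract the first column
  predict(model, newdata) %>% pluck(1)
}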
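The wrapper contract is the part that usually trips people up: fastshap::explain() calls pred_wrapper(object, newdata) and expects one number per row of newdata. A minimal base R sketch of that contract, assuming an lm() model on mtcars as a stand-in for the workflow (pred_wrapper_lm() is a hypothetical name):

```r
# Stand-in model: base R lm() instead of a tidymodels workflow (assumption)
fit_lm <- lm(mpg ~ wt + hp + qsec, data = mtcars)

# Hypothetical wrapper: takes (object, newdata), returns a plain numeric vector
pred_wrapper_lm <- function(object, newdata) {
  as.numeric(predict(object, newdata = newdata))  # as.numeric() drops names
}

X <- mtcars[, c("wt", "hp", "qsec")]
preds <- pred_wrapper_lm(fit_lm, X)
str(preds)
```

For a tidymodels workflow the extra step is only the extraction of .pred from the returned tibble, which is what pluck(1) does above.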
Let's get the mean prediction value (used below as the baseline) while we're at it. This also serves as a check that the wrapper works.
mean_preds <- mean(
predict_function_gbm(
fit_workflow, diamonds %>% select(-price)
)
)
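Now we create our explanations (SHAP values). Note the pred_wrapper and X arguments here (see the fastshap GitHub issues for other examples, e.g. glmnet). nsim sets the number of Monte Carlo repetitions per Shapley value; larger values give more accurate estimates but take longer.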
fastshap::explain(
fit_workflow,
X = as.data.frame(diamonds %>% select(-price)),
pred_wrapper = predict_function_gbm,
nsim = 10
) -> explanations_gbm
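To make nsim less of a black box, here is a simplified, unvectorized sketch of the Monte Carlo permutation-sampling idea behind approximate Shapley values, for one feature of one observation. It is not fastshap's implementation; lm() on mtcars stands in for the workflow, and mc_shapley() and pred_wrapper_lm() are hypothetical names. For a linear model with the training rows as background data, the exact value is beta_j * (x_j - mean(x_j)), which gives something to compare against.

```r
# Hypothetical helper: Monte Carlo estimate of one feature's Shapley value
# for one observation x (unvectorized sketch, not fastshap's implementation)
mc_shapley <- function(object, X, x, feature, pred_wrapper, nsim = 100) {
  X <- as.matrix(X)                 # assumes all features are numeric
  x <- unlist(x)[colnames(X)]       # observation of interest as a named vector
  p <- ncol(X)
  j <- match(feature, colnames(X))
  contrib <- numeric(nsim)
  for (m in seq_len(nsim)) {
    z <- X[sample.int(nrow(X), 1), ]          # random background observation
    o <- sample.int(p)                        # random feature ordering
    before <- o[seq_len(which(o == j) - 1)]   # features that precede j
    b1 <- z
    b2 <- z
    b1[c(before, j)] <- x[c(before, j)]       # feature j taken from x
    b2[before] <- x[before]                   # feature j taken from z
    pr <- pred_wrapper(object, as.data.frame(rbind(b1, b2)))
    contrib[m] <- pr[1] - pr[2]               # marginal contribution of j
  }
  mean(contrib)
}

fit_lm <- lm(mpg ~ wt + hp + qsec, data = mtcars)
pred_wrapper_lm <- function(object, newdata) {
  as.numeric(predict(object, newdata = newdata))
}
X <- mtcars[, c("wt", "hp", "qsec")]

set.seed(1)
phi_wt <- mc_shapley(fit_lm, X, X["Lincoln Continental", ], "wt",
                     pred_wrapper_lm, nsim = 1000)

# Exact value for a linear model: beta_wt * (x_wt - mean(wt))
exact_wt <- coef(fit_lm)[["wt"]] * (X["Lincoln Continental", "wt"] - mean(X$wt))
c(estimate = phi_wt, exact = exact_wt)
```

With a small nsim (such as the 10 used above) the estimate is noisier; raising it trades run time for accuracy.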
This should produce a force plot for the first observation:
fastshap::force_plot(
object = explanations_gbm[1,],
feature_values = as.data.frame(diamonds %>% select(-price))[1,],
display = "viewer",
baseline = mean_preds)
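The baseline matters because SHAP values are additive: for each observation, baseline + the sum of its SHAP values equals its prediction (exactly for exact Shapley values, only approximately for Monte Carlo estimates). That is why the mean prediction is passed as baseline. A base R sketch of this property, assuming a linear model (lm() on mtcars as a stand-in) with features treated as independent, where exact SHAP values are beta_j * (x_j - mean(x_j)):

```r
fit_lm <- lm(mpg ~ wt + hp + qsec, data = mtcars)
X <- mtcars[, c("wt", "hp", "qsec")]
beta <- coef(fit_lm)[colnames(X)]

# Exact SHAP values for a linear model: beta_j * (x_j - mean(x_j))
centered <- sweep(as.matrix(X), 2, colMeans(X))
shap_exact <- sweep(centered, 2, beta, "*")

# Baseline = mean prediction over the data, analogous to mean_preds
baseline <- mean(predict(fit_lm, X))

# Additivity: baseline + row sums of the SHAP values reproduce the predictions
reconstructed <- baseline + rowSums(shap_exact)
all.equal(unname(reconstructed), unname(predict(fit_lm, X)))
```

An analogous check on the answer's objects would compare rowSums(explanations_gbm) + mean_preds with predict_function_gbm(fit_workflow, ...); with nsim = 10, expect only rough agreement.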
Multiple observations can be shown at once, stacked vertically:
fastshap::force_plot(
object = explanations_gbm[1:20,],
feature_values = as.data.frame(diamonds %>% select(-price))[1:20,],
display = "viewer",
baseline = mean_preds)
Add link = "logit" for classification. Change display to "html" for R Markdown rendering.
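For classification the wrapper still has to return one number per row. A sketch, assuming that with link = "logit" the SHAP values and baseline should be on the log-odds scale (the Python shap force_plot() documents its logit link as turning log-odds into probabilities for display): convert the positive-class probability with qlogis(). A glm() on mtcars stands in for a tidymodels classification workflow, and pred_wrapper_logodds() is hypothetical.

```r
# Binary classifier standing in for a tidymodels classification workflow
fit_glm <- glm(am ~ hp + wt, family = binomial, data = mtcars)

# Hypothetical wrapper: probability of the positive class -> log-odds
pred_wrapper_logodds <- function(object, newdata) {
  p <- predict(object, newdata = newdata, type = "response")
  # For a workflow, p would instead come from something like
  # predict(object, newdata, type = "prob")$.pred_1 (column name assumed)
  as.numeric(qlogis(p))
}

X <- mtcars[, c("hp", "wt")]
logodds <- pred_wrapper_logodds(fit_glm, X)
baseline_logodds <- mean(logodds)
```

Under this assumption, baseline_logodds rather than a mean probability would be passed as baseline to force_plot().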
Now for summary plots and dependence plots.
The trick is to use reticulate to call the shap functions directly. The same approach works for other Python libraries such as transformers, numpy, etc.
First, the dependence plot:
library(reticulate)
shap <- import("shap")
np <- import("numpy")
shap$dependence_plot(
  "rank(3)",
  data.matrix(explanations_gbm),
  data.matrix(diamonds %>% select(-price))
)
See the shap docs for an explanation of "rank(3)"; "rank(1)" etc. will also work.
Unfortunately, it threw an error when I attempted to name the feature directly (e.g., "cut").
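A likely (unverified) reason the feature name fails: data.matrix() keeps column names on the R side, but reticulate converts an R matrix into a NumPy array, which has no column names, so shap has nothing to match "cut" against. shap's dependence_plot() accepts a feature_names argument, so passing the column names explicitly may help. plot_shap_dependence() is a hypothetical helper and assumes shap is installed in the active Python environment.

```r
# Hypothetical helper; assumes shap is available to reticulate
plot_shap_dependence <- function(feature, shap_values, X) {
  shap <- reticulate::import("shap")
  X_mat <- data.matrix(X)
  shap$dependence_plot(
    feature,                           # e.g. "cut", looked up in feature_names
    data.matrix(shap_values),
    X_mat,
    feature_names = colnames(X_mat)    # names are lost in the NumPy conversion
  )
}
```

Usage would look like plot_shap_dependence("cut", explanations_gbm, diamonds %>% select(-price)).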
Now for the summary plot:
shap$summary_plot(
  data.matrix(explanations_gbm),
  data.matrix(diamonds %>% select(-price))
)
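summary_plot() orders features by the total magnitude of their SHAP values (the sum of |SHAP| over all observations); the mean gives the same ordering. If the Python setup is troublesome, that ranking can be drawn in base R. A sketch using exact linear-model SHAP values on mtcars as a stand-in; for the answer's output, colMeans(abs(as.matrix(explanations_gbm))) would play the role of importance, assuming it converts cleanly to a numeric matrix.

```r
fit_lm <- lm(mpg ~ wt + hp + qsec, data = mtcars)
X <- mtcars[, c("wt", "hp", "qsec")]
beta <- coef(fit_lm)[colnames(X)]

# Exact SHAP values for a linear model (stand-in for explanations_gbm)
shap_mat <- sweep(sweep(as.matrix(X), 2, colMeans(X)), 2, beta, "*")

# Global importance: mean absolute SHAP value per feature, largest first
importance <- sort(colMeans(abs(shap_mat)), decreasing = TRUE)
barplot(rev(importance), horiz = TRUE, las = 1, xlab = "mean(|SHAP value|)")
```

This shows only the ranking; the beeswarm in shap's summary plot additionally colors each point by its feature value.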
Final note: rendering the plot repeatedly can produce buggy visualizations. Hopefully this provides a point of departure for catboost visualizations.
Upvotes: 6