mugdi
Reputation: 415

Create SHAP plots for tidymodel objects

This question refers to Obtaining summary shap plot for catboost model with tidymodels in R. Given the comment below the question, the OP found a solution but did not share it with the community so far.

I want to analyze my tree ensembles fitted with the tidymodels package with SHAP value plots such as plots for single observations like

https://prnt.sc/CO_PC4aDUQA0

and to summarize the effect of all features of my dataset like

[image: SHAP summary plot]

DALEXtra provides a function, explain_tidymodels(), to create SHAP values for tidymodels objects. force_plot() from the fastshap package provides a wrapper for the plot function of the underlying Python package shap. But I can't work out how to make force_plot() work with the output of explain_tidymodels().

Question: How can one generate such SHAP plots in R using tidymodels and explain_tidymodels()?

MWE (for SHAP values with explain_tidymodels())

library(MASS)
library(tidyverse)
library(tidymodels)
library(parsnip)
library(treesnip)
library(catboost)
library(fastshap)
library(DALEXtra)
set.seed(1337)
rec <-  recipe(crim ~ ., data = Boston)

split <- initial_split(Boston)

train_data <- training(split)

test_data <- testing(split) %>% dplyr::select(-crim) %>% as.matrix()

model_default<-
  parsnip::boost_tree(
    mode = "regression"
  ) %>%
  set_engine(engine = 'catboost', loss_function = 'RMSE')
# Sometimes catboost is not loaded correctly; the following two lines
# prevent fitting errors (see the last post in
# https://github.com/curso-r/treesnip/issues/21).
set_dependency("boost_tree", eng = "catboost", "catboost")
set_dependency("boost_tree", eng = "catboost", "treesnip")

model_fit_wf <- workflow() %>% add_model(model_default) %>% add_recipe(rec) %>% {parsnip::fit(object = ., data = train_data)}

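# X was undefined in the original; the training predictors are assumed here
X <- train_data %>% dplyr::select(-crim)
SHAP_wf <- explain_tidymodels(model_fit_wf, data = X, y = train_data$crim, new_data = test_data)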

Upvotes: 9

Views: 2617

Answers (1)

user18815063

Reputation: 76

Perhaps this will help. At the very least, it is a step in the right direction.

First, ensure you have fastshap and reticulate installed (i.e., install.packages("...")). Next, set up a virtual environment and install shap (pip install ...). Also, install matplotlib 3.2.2 for the dependency plots (check out GitHub issues on this -- an older version of matplotlib is necessary).

RStudio has great information on virtual environment setup. That said, how much troubleshooting the setup requires depends on the IDE in use. (Sadly, some work settings restrict the use of open source RStudio due to licensing.)

Docs for library(fastshap) are also helpful on this front.
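As a rough sketch of the setup described above, the environment can also be created from R with reticulate instead of the command line. The environment name "r-shap" is a placeholder of mine, not something from the docs, and the pinned matplotlib version may not install on recent Python releases.

```r
library(reticulate)

# "r-shap" is a hypothetical environment name; any name works.
env_name <- "r-shap"

# Create the environment once, then install the Python packages into it.
virtualenv_create(env_name)
virtualenv_install(env_name, packages = c("shap", "matplotlib==3.2.2"))

# Activate it before the first import() call in a fresh R session.
use_virtualenv(env_name, required = TRUE)

# Should return TRUE if shap was installed into the active environment.
py_module_available("shap")
```

Run the create and install steps once; in later sessions only use_virtualenv() is needed before importing shap.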

Here's a workflow for lightgbm (from treesnip docs, lightly modified).

library(tidymodels)
library(treesnip)

data("diamonds", package = "ggplot2")
diamonds <- diamonds %>% sample_n(1000)

# vfold resamples
diamonds_splits <- vfold_cv(diamonds, v = 5)

model_spec <- boost_tree(mtry = 5, trees = 500) %>% set_mode("regression")

# model specs
lightgbm_model <- model_spec %>% 
    set_engine("lightgbm", nthread = 6)

#workflows
lightgbm_wf <- workflow() %>% 
    add_model(
       lightgbm_model
    )

rec_ordered <- recipe(
    price ~ .
      , data = diamonds
) 

lightgbm_fit_ordered <- fit_resamples(
  add_recipe(
    lightgbm_wf, rec_ordered
    ), resamples = diamonds_splits)

Before predicting, we need to fit the workflow:

fit_workflow <- lightgbm_wf %>% 
     add_recipe(rec_ordered) %>% 
     fit(data = diamonds)

Now we have a fitted workflow and can predict. To use the fastshap::explain function, we need to supply a prediction function (this is not always required: depending on the engine used, explain may or may not work out of the box -- see docs).

predict_function_gbm <-  function(model, newdata) {
    predict(model, newdata) %>% pluck(.,1)
}

Let's get the mean prediction value (used below as the baseline) while we're at it. This also serves as a check that the prediction function works.

mean_preds <- mean(
    predict_function_gbm(
       fit_workflow, diamonds %>% select(-price)
   )
)

Now we create our explanations (shap values). Note the pred_wrapper and X arguments here (see fastshap github issues for other examples -- i.e. glmnet).

fastshap::explain( 
    fit_workflow, 
    X = as.data.frame(diamonds %>% select(-price)),
    pred_wrapper = predict_function_gbm, 
    nsim = 10
) -> explanations_gbm
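The result has one row per observation and one column per feature, so a quick global summary is possible in plain R before any Python plotting. The sketch below uses a hypothetical helper, mean_abs_shap(), and made-up toy numbers in place of explanations_gbm; it ranks features by mean absolute SHAP value, which is one common way to summarize importance.

```r
# Hypothetical helper: rank features by their mean absolute SHAP value.
mean_abs_shap <- function(shap_values) {
  shap_mat <- data.matrix(shap_values)      # rows = observations, cols = features
  importance <- colMeans(abs(shap_mat))     # average magnitude per feature
  sort(importance, decreasing = TRUE)
}

# Toy SHAP values standing in for explanations_gbm (invented numbers).
toy_shap <- data.frame(
  carat = c(-500, 100, 900),
  depth = c(20, -10, 5),
  table = c(0, 30, -60)
)

importance <- mean_abs_shap(toy_shap)
print(importance)
```

With the real object, mean_abs_shap(explanations_gbm) should work the same way, and barplot(rev(importance), horiz = TRUE, las = 1) gives a simple bar chart of the ranking.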

This should produce a force plot.

fastshap::force_plot(
    object = explanations_gbm[1,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1,], 
    display = "viewer", 
    baseline = mean_preds) 

This produces multiple force plots, stacked vertically:

fastshap::force_plot(
    object = explanations_gbm[1:20,], 
    feature_values = as.data.frame(diamonds %>% select(-price))[1:20,], 
    display = "viewer", 
    baseline = mean_preds) 

Add link = "logit" for classification. Change display to "html" for Rmarkdown rendering.

Now for summary plots and dependency plots.

The trick is using reticulate to access the Python functions directly. Note that the same logic holds for libraries like transformers, numpy, etc.

First, the dependence plot.

library(reticulate)
shap = import("shap")
np = import("numpy") 

shap$dependence_plot(
     "rank(3)", 
     data.matrix(explanations_gbm),
     data.matrix(diamonds %>% select(-price))
)

See shap docs for explanation of rank(3) -- rank(1) etc will also work.

Unfortunately, it threw an error when I attempted to name the feature directly (i.e., "cut").
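One possible cause, which I have not verified on this exact setup: reticulate converts an R matrix to a NumPy array, and that drops the column names, so shap has nothing to match a name like "cut" against. shap's dependence_plot accepts a feature_names argument, so passing the names separately may help. The helper prepare_shap_inputs() and the toy data below are hypothetical stand-ins for explanations_gbm and the diamonds predictors.

```r
# Hypothetical helper: bundle SHAP values, features and feature names.
# Column names are kept separately because they are lost when the
# matrices are converted to NumPy arrays.
prepare_shap_inputs <- function(shap_values, features) {
  shap_mat <- data.matrix(shap_values)
  feat_mat <- data.matrix(features)
  stopifnot(identical(dim(shap_mat), dim(feat_mat)))
  list(
    shap_values   = shap_mat,
    features      = feat_mat,
    feature_names = colnames(feat_mat)
  )
}

# Toy data standing in for explanations_gbm and the predictors.
toy_features <- data.frame(carat = c(0.3, 0.7, 1.1), depth = c(61, 62, 60))
toy_shap     <- data.frame(carat = c(-500, 100, 900), depth = c(20, -10, 5))
inputs <- prepare_shap_inputs(toy_shap, toy_features)

# Only attempt the plot when reticulate and the Python shap module exist.
if (requireNamespace("reticulate", quietly = TRUE) &&
    reticulate::py_module_available("shap")) {
  shap <- reticulate::import("shap")
  shap$dependence_plot(
    "carat",                              # feature referenced by name
    inputs$shap_values,
    inputs$features,
    feature_names = inputs$feature_names
  )
}
```

The same feature_names argument is also accepted by shap$summary_plot, which would label the features by name rather than by position.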

Now for the summary plot:

shap$summary_plot( 
    data.matrix(explanations_gbm),
    data.matrix(diamonds %>% select(-price))
)

Final note: rendering the plot repeatedly will produce buggy visualizations. Hopefully this provides a point of departure for catboost visualizations.

Upvotes: 6