Tengku Hanis

Reputation: 87

Remove variables from a model_parts() plot

I want to remove certain variables from the variable importance plot created with model_parts().

# Packages 
library(tidymodels)
library(mlbench)

# Data 
data("PimaIndiansDiabetes")
dat <- PimaIndiansDiabetes 
dat$some_new_group[1:384] <- "group 1" 
dat$some_new_group[385:768] <- "group 2"

# Split
set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)
dat_test <- testing(ind)

# Recipes
svm_rec <- 
  recipe(diabetes ~., data = dat_train) %>% 
  update_role(some_new_group, new_role = "group_var") %>% 
  step_rm(pressure) %>% 
  step_YeoJohnson(all_numeric_predictors())
    
# Model spec 
svm_spec <- 
  svm_rbf() %>% 
  set_mode("classification") %>% 
  set_engine("kernlab")

# Workflow 
svm_wf <- 
  workflow() %>% 
  add_recipe(svm_rec) %>% 
  add_model(svm_spec)

# Train
svm_trained <- 
  svm_wf %>% 
  fit(dat_train)

# Explainer
library(DALEXtra)

svm_exp <- explain_tidymodels(svm_trained, 
                              data = dat %>% select(-diabetes), 
                              y = dat$diabetes %>% as.numeric(), 
                              label = "SVM")
# Variable importance
set.seed(123)
svm_vp <- model_parts(svm_exp, type = "variable_importance") 
svm_vp

plot(svm_vp) +
  ggtitle("Mean-variable importance over 50 permutations", "") 


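Notice that in the recipe above I removed the variable pressure with step_rm() and gave the new categorical variable some_new_group the role "group_var", so neither is used as a predictor by the model.

I can remove pressure and some_new_group from the plot manually by filtering the model_parts() result, like this: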

plot(svm_vp %>% filter(!variable %in% c("pressure", "some_new_group"))) +
  ggtitle("Mean-variable importance over 50 permutations", "") 
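When filtering on several names, use `%in%` rather than `!=`. `variable != c("pressure", "some_new_group")` is not a membership test: R recycles the two-element vector and compares element by element, so whether a row is dropped depends on its position. A small base-R illustration (the vector `v` is made up for the example):

```r
v <- c("glucose", "pressure", "some_new_group", "mass")
drop <- c("pressure", "some_new_group")

# Pairwise comparison with recycling:
# glucose vs pressure, pressure vs some_new_group,
# some_new_group vs pressure, mass vs some_new_group
v != drop      # all TRUE, so nothing would be filtered out

# Membership test: FALSE exactly for the names listed in `drop`
!v %in% drop   # TRUE FALSE FALSE TRUE
```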


But is it possible to exclude these variables already when I run explain_tidymodels() or model_parts(), instead of filtering the result afterwards?

Upvotes: 0

Views: 377

Answers (1)

Julia Silge

Reputation: 11623

If your data contains variables that are neither predictors nor outcomes for your workflow() (like the variable you remove with step_rm() and your grouping variable), make sure you pass only the predictors (as data) and the outcome (as y) to explain_tidymodels(). You'll also need to build the explainer with the parsnip model rather than the workflow(), because the workflow expects to receive and handle those non-outcome, non-predictor variables itself:

library(tidymodels)

# Data 
data("PimaIndiansDiabetes", package = "mlbench")
dat <- PimaIndiansDiabetes 
dat$some_new_group[1:384] <- "group 1" 
dat$some_new_group[385:768] <- "group 2"

# Split
set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)
dat_test <- testing(ind)

# Recipes
svm_rec <- 
  recipe(diabetes ~., data = dat_train) %>% 
  update_role(some_new_group, new_role = "group_var") %>% 
  step_rm(pressure) %>% 
  step_YeoJohnson(all_numeric_predictors())

# Model spec 
svm_spec <- 
  svm_rbf() %>% 
  set_mode("classification") %>% 
  set_engine("kernlab")

# Train
svm_trained <- 
  workflow(svm_rec, svm_spec) %>% 
  fit(dat_train)

# Explainer
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.0).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> 
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain

svm_exp <- explain_tidymodels(
  extract_fit_parsnip(svm_trained), 
  data = svm_rec %>% prep() %>% bake(new_data = NULL, all_predictors()), 
  y = dat_train$diabetes %>% as.numeric(), 
  label = "SVM"
)
#> Preparation of a new explainer is initiated
#>   -> model label       :  SVM 
#>   -> data              :  576  rows  7  cols 
#>   -> data              :  tibble converted into a data.frame 
#>   -> target variable   :  576  values 
#>   -> predict function  :  yhat.model_fit  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package parsnip , ver. 0.2.1 , task classification (  default  ) 
#>   -> predicted values  :  numerical, min =  0.08057345 , mean =  0.3540662 , max =  0.9357536  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  0.1083522 , mean =  0.9948921 , max =  1.895405  
#>   A new explainer has been created!

# Variable importance
set.seed(123)
svm_vp <- model_parts(svm_exp, type = "variable_importance") 
svm_vp
#>       variable mean_dropout_loss label
#> 1 _full_model_         0.6861190   SVM
#> 2      glucose         0.5919956   SVM
#> 3         mass         0.6673947   SVM
#> 4     pregnant         0.6700007   SVM
#> 5          age         0.6701185   SVM
#> 6     pedigree         0.6702812   SVM
#> 7      triceps         0.6760106   SVM
#> 8      insulin         0.6777355   SVM
#> 9   _baseline_         0.5020752   SVM

plot(svm_vp) +
  ggtitle("Mean-variable importance over 50 permutations", "") 

Created on 2022-05-03 by the reprex package (v2.0.1)
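The key step above is `bake(new_data = NULL, all_predictors())`. To see which columns it keeps, you can compare the recipe's roles before and after estimation. A minimal sketch, assuming the same Pima data and recipe; for brevity it uses the full data set instead of the training split, which does not change the roles:

```r
library(tidymodels)

data("PimaIndiansDiabetes", package = "mlbench")
dat <- PimaIndiansDiabetes
dat$some_new_group <- rep(c("group 1", "group 2"), each = 384)

svm_rec <-
  recipe(diabetes ~ ., data = dat) %>%
  update_role(some_new_group, new_role = "group_var") %>%
  step_rm(pressure) %>%
  step_YeoJohnson(all_numeric_predictors())

# Before prep(): step_rm() has not run, so pressure is still listed,
# while some_new_group already carries the role "group_var"
roles <- summary(svm_rec)
roles

# After prep(): baking with all_predictors() returns only the columns
# the model is fitted on (no outcome, no pressure, no grouping variable)
predictors_only <- svm_rec %>% prep() %>% bake(new_data = NULL, all_predictors())
names(predictors_only)
```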

If you have these "extra" variables in your workflow that shouldn't be used for explainability, then you'll need to do some extra work and can't rely on the workflow() alone.
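One way to package that extra work is a small helper. The sketch below is only an illustration: `explain_predictors_only()` is a hypothetical name, not a DALEXtra function. It reuses the recipe already estimated inside the fitted workflow (`extract_recipe()`) instead of calling `prep()` again. It also codes the outcome as 0/1, which differs from the code above: `as.numeric()` on the `diabetes` factor returns 1 for "neg" and 2 for "pos" (consistent with the mean residual of about 0.99 in the explainer output), while DALEX's default classification loss, one minus AUC, expects a 0/1 target. It assumes "pos" is the class whose probability the explainer predicts. Because of the recoding, the dropout losses will not match the reprex output above.

```r
library(tidymodels)
library(DALEXtra)

data("PimaIndiansDiabetes", package = "mlbench")
dat <- PimaIndiansDiabetes
dat$some_new_group <- rep(c("group 1", "group 2"), each = 384)

set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)

svm_rec <-
  recipe(diabetes ~ ., data = dat_train) %>%
  update_role(some_new_group, new_role = "group_var") %>%
  step_rm(pressure) %>%
  step_YeoJohnson(all_numeric_predictors())

svm_spec <-
  svm_rbf() %>%
  set_mode("classification") %>%
  set_engine("kernlab")

svm_trained <- workflow(svm_rec, svm_spec) %>% fit(dat_train)

# Hypothetical helper (not part of DALEXtra): explain only the parsnip fit,
# passing it the predictors produced by the workflow's trained recipe
explain_predictors_only <- function(wf_fit, new_data, outcome, positive, label) {
  rec <- extract_recipe(wf_fit)  # recipe as estimated during fit()
  explain_tidymodels(
    extract_fit_parsnip(wf_fit),
    data = bake(rec, new_data = new_data, all_predictors()),
    # 0/1 target: 1 for the positive class, 0 otherwise
    y = as.integer(new_data[[outcome]] == positive),
    label = label,
    verbose = FALSE
  )
}

svm_exp <- explain_predictors_only(svm_trained, dat_train, "diabetes", "pos", "SVM")

set.seed(123)
svm_vp <- model_parts(svm_exp, type = "variable_importance")
plot(svm_vp)
```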

Upvotes: 1
