kris

Reputation: 141

SHAP values beeswarm plot of top predictors without "Sum of others"

I want to produce a beeswarm plot of the top 15 predictors of my target, as ranked by the SHAP value analysis. I'm using sv_importance() from the R package shapviz for that:

p <- sv_importance(shp, kind = "beeswarm", alpha = 0.5, show_numbers = TRUE, max_display = 15)

However, the function also adds a "Sum of x other" row to the plot. Do you have any suggestions on how to get rid of it?

Upvotes: 2

Views: 1182

Answers (2)

Michael M

Reputation: 1593

Thanks to this question, "shapviz" 0.4.1 now has a new option, sv_importance(..., show_other = FALSE), to control this:

library(shapviz)

set.seed(1)

# Toy data: 100 rows of 26 standard-normal features named A to Z
X_train <- data.matrix(`colnames<-`(replicate(26, rnorm(100)), LETTERS))
dtrain <- xgboost::xgb.DMatrix(X_train, label = rnorm(100))
fit <- xgboost::xgb.train(data = dtrain, nrounds = 50)
# Compute SHAP values for the training rows
shp <- shapviz(fit, X_pred = X_train)

# show_other = FALSE drops the "Sum of ... other" row
sv_importance(
  shp, kind = "beeswarm", show_other = FALSE, max_display = 15, show_numbers = TRUE
)


Upvotes: 1

Allan Cameron

Reputation: 173803

It's always easier to answer if we have a reproducible example, but I'll create one for the purposes of this answer:

library(shapviz)

set.seed(1)

# Toy data: 100 rows of 26 standard-normal features named A to Z
X_train <- data.matrix(`colnames<-`(replicate(26, rnorm(100)), LETTERS))
dtrain <- xgboost::xgb.DMatrix(X_train, label = rnorm(100))
fit <- xgboost::xgb.train(data = dtrain, nrounds = 50)
# Compute SHAP values for the training rows
shp <- shapviz(fit, X_pred = X_train)

Now we have a shapviz object called shp, with 26 features, so this should be comparable to your situation.

If we plot with sv_importance(), we get a ggplot object in which the second row from the top is an undesired one labelled "Sum of 12 other": 14 features are shown individually and the remaining 12 are collapsed into that row.

p <- sv_importance(shp, kind = "beeswarm", show_numbers = TRUE, max_display = 15) 

p
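The count in that label follows from max_display: one of the 15 slots is used by the aggregated row, so 14 features are shown and 26 - 14 = 12 are collapsed. The base-R sketch below reproduces that bookkeeping on a random stand-in matrix (S is hypothetical, not real SHAP values). It assumes features are ranked by mean absolute SHAP value, which is how shapviz defines importance, and that the aggregated row is the row-wise sum of the remaining SHAP values; the latter is an assumption about the package internals suggested by the "Sum of" label.

```r
set.seed(1)

# Hypothetical stand-in for a 100 x 26 matrix of SHAP values (random numbers)
S <- matrix(rnorm(100 * 26), nrow = 100, dimnames = list(NULL, LETTERS))
max_display <- 15

# Importance as mean absolute SHAP value per feature, largest first
imp <- sort(colMeans(abs(S)), decreasing = TRUE)

# One of the max_display slots goes to the aggregated row,
# so only max_display - 1 features are shown individually
shown <- names(imp)[seq_len(max_display - 1)]
other <- setdiff(names(imp), shown)

# Assumption: the aggregated row is the row-wise sum of the remaining SHAP values
S_other <- rowSums(S[, other, drop = FALSE])

cat(sprintf("Shown: %d features, plus \"Sum of %d other\"\n",
            length(shown), length(other)))
```

With 26 features this prints that 14 features are shown plus "Sum of 12 other", matching the label in the plot.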

To get rid of the "Sum of 12 other" row, we remove the rows referring to it from the data stored in the plot object. It appears in two different data frames, the plot's main data (p$data) and the data of the third layer (p$layers[[3]]$data), so we have to filter both:

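# Drop the "Sum of ..." rows from the plot's main data
p$data <- p$data[!grepl("Sum of", p$data$feature), ]
# The third layer carries its own copy of the data, so filter it as well
p$layers[[3]]$data <- p$layers[[3]]$data[!grepl("Sum of", p$layers[[3]]$data$feature), ]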

Now when we draw p again, the offending row is removed:

p

Created on 2022-10-30 with reprex v2.0.2
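Hard-coding p$layers[[3]] ties the fix to the layer order of one shapviz version. A more defensive variant loops over every layer and filters only those that carry their own data frame with a feature column. This is a sketch: drop_other() is a hypothetical helper, not part of shapviz, and it assumes, as the code above does, that the plot data can be modified through p$data and p$layers[[i]]$data.

```r
# Hypothetical helper (not part of shapviz): remove the aggregated
# "Sum of ... other" row from a plot without hard-coding a layer index.
drop_other <- function(p, pattern = "^Sum of") {
  keep_rows <- function(d) {
    # Layers that inherit the plot data store a placeholder (waiver())
    # instead of a data frame; only filter real data frames that have
    # a 'feature' column and leave everything else untouched.
    if (is.data.frame(d) && "feature" %in% names(d)) {
      d <- d[!grepl(pattern, d$feature), , drop = FALSE]
    }
    d
  }
  p$data <- keep_rows(p$data)
  for (i in seq_along(p$layers)) {
    p$layers[[i]]$data <- keep_rows(p$layers[[i]]$data)
  }
  p
}
```

Usage with the example above: p <- drop_other(sv_importance(shp, kind = "beeswarm", show_numbers = TRUE, max_display = 15)).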

Upvotes: 2
