geo_dd

Reputation: 313

How to use the pdp package in R to create 3D partial dependence plots?

I have a random forest model in R similar to this:

library("randomForest")
library("caret")
library("pdp")
data("cars")
my_data <- cars[1:5]
my_rf <- randomForest(Price ~ ., data = my_data)
price_mil <- partial(my_rf, pred.var = c("Price", "Mileage"))
plotPartial(price_mil, levelplot = FALSE, zlab = "Price", colorkey = TRUE)

However, I would like to produce 3D partial dependence plots that show the predictor values on the axes. How can I do this with pdp?

Upvotes: 2

Views: 2012

Answers (2)

gm1991

Reputation: 303

Interactive 3D Partial Dependence Plot with plotly

# Random seed to reproduce the results
set.seed(1)

# Create artificial data for a binary classification problem 
y <- factor(sample(c(0,1), size = 100, replace = TRUE), levels = c("0", "1"))
d <- data.frame(y = y, x1 = rnorm(100), x2 = rnorm(100), x3 = rnorm(100))

# Build a random forest model
library(randomForest)
# Note: randomForest() takes ntree (not n.trees) for the number of trees
rf1 <- randomForest(y ~ ., data = d, ntree = 100, mtry = 2)

###### Bivariate partial dependency plots ######
# Step 1: compute the partial dependence values
# given two variables using the pdp library
library(pdp)
pd <- partial(rf1, pred.var = c("x1", "x2"))

# Step 2: construct the plot using the plotly library 
library(plotly)
p <- plot_ly(x = pd$x1, y = pd$x2, z = pd$yhat, type = 'mesh3d')

# Step 3: add labels to the plot
p <- p %>% layout(scene = list(xaxis = list(title = "x1"),
                           yaxis = list(title = "x2"),
                           zaxis = list(title = "Partial Dependence")))

# Step 4: show the plot 
show(p)


Interactive Contour Plot (i.e. flattened 2-variable PDP) with a color scale for the partial dependence values using plotly

###### Bivariate PDPs with colored scale ######
# Interpolate the partial dependence values 
dens <- akima::interp(x = pd$x1, y = pd$x2, z = pd$yhat)

# Flattened contour partial dependence plot for 2 variables
p2 <- plot_ly(x = dens$x,
          y = dens$y, 
          z = dens$z, 
          colors = c("blue", "grey", "red"),
          type = "contour")

# Add axis labels for 2D plots
p2 <- p2 %>% layout(xaxis = list(title = "x1"), yaxis = list(title = "x2"))
# Show the plot
show(p2)


Interactive 3D Partial Dependence Plot with a color scale for the partial dependence values using plotly

###### Interactive 3D partial dependence plot with coloring scale ######

# Interpolate the partial dependence values
dens <- akima::interp(x = pd$x1, y = pd$x2, z = pd$yhat)

# 3D partial dependence plot with a coloring scale
p3 <- plot_ly(x = dens$x, 
          y = dens$y, 
          z = dens$z,
          colors = c("blue", "grey", "red"),
          type = "surface")
# Add axis labels for 3D plots
p3 <- p3 %>% layout(scene = list(xaxis = list(title = "x1"),
                             yaxis = list(title = "x2"),
                             zaxis = list(title = "Partial Dependence")))
# Show the plot
show(p3)
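
By default, partial() appears to evaluate the model on a regular grid of the two predictors, so the akima interpolation step can likely be skipped by reshaping the result directly. The following is a minimal sketch under that assumption; pd_to_surface is a hypothetical helper (it is not part of pdp or plotly), and demo_pd is a small stand-in for the data frame returned by partial().

```r
# Hypothetical helper (not part of pdp or plotly): reshape a two-predictor
# partial dependence data frame into the x / y / z pieces that a plotly
# surface trace expects. Assumes the rows cover a complete regular grid.
pd_to_surface <- function(pd, xvar, yvar, zvar = "yhat") {
  xs <- sort(unique(pd[[xvar]]))
  ys <- sort(unique(pd[[yvar]]))
  z <- matrix(NA_real_, nrow = length(ys), ncol = length(xs))
  # plotly surfaces read z as z[row = y value, column = x value]
  z[cbind(match(pd[[yvar]], ys), match(pd[[xvar]], xs))] <- pd[[zvar]]
  list(x = xs, y = ys, z = z)
}

# Stand-in for the output of partial(): a regular grid plus a yhat column
demo_pd <- expand.grid(x1 = c(-1, 0, 1), x2 = c(10, 20))
demo_pd$yhat <- demo_pd$x1 + demo_pd$x2 / 10

s <- pd_to_surface(demo_pd, "x1", "x2")
dim(s$z)  # rows follow x2, columns follow x1
```

With the real pd object, the result could then be passed on as plot_ly(x = s$x, y = s$y, z = s$z, type = "surface"). If the grid was restricted (for example with chull = TRUE), some cells stay NA and interpolation may still be the better choice.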


Upvotes: 3

pookpash

Reputation: 898

First of all, your example passes "Price" to pred.var in the partial() function. This does not make sense to me: Price is the response of the model, so you are essentially just plotting a 2D partial dependence plot that way. I changed that in my example code below.

However, to get the requested partial plots you can use

plotPartial(price_mil, zlab = "Price", levelplot = FALSE, scales = list(arrows = FALSE))

If you want more control, I would advise using the underlying functions of the package to construct your formula and wireframe object yourself, and then calling wireframe() with scales = list(arrows = FALSE) to add the values to the axes.

library("randomForest")
library("caret")
library("pdp")
data("cars")
my_data <- cars[1:5]
my_rf <- randomForest(Price ~ ., data = my_data)

object <- pdp::partial(my_rf, pred.var = c("Cylinder", "Mileage"))

form <- stats::as.formula(paste("yhat ~", paste(names(object)[1L:2L], 
                                                collapse = "*")))

# wireframe() comes from lattice, which is attached when caret is loaded
wireframe(form, data = object, drape = TRUE, zlab = "Price", scales = list(arrows = FALSE))

This yields a draped wireframe surface with numeric tick labels on the axes.

Upvotes: 4
