I'm trying to generate interactive partial dependence plots by looping over the columns in the data set.
A reproducible example:
library(pdp)
library(xgboost)
library(Matrix)
library(ggplot2)
library(plotly)
data(mtcars)
target <- mtcars$mpg
mtcars$mpg <- NULL
mtcars.sparse <- sparse.model.matrix(target~., mtcars)
fit <- xgboost(data=mtcars.sparse, label=target, nrounds=100)
for (i in seq_along(names(mtcars))){
  p1 <- pdp::partial(fit,
                     pred.var = names(mtcars)[i],
                     pred.grid = data.frame(unique(mtcars[names(mtcars)[i]])),
                     train = mtcars.sparse,
                     type = "regression",
                     cats = c("cyl", "vs", "am", "gear", "carb"),
                     plot = FALSE)
  p2 <- ggplot(aes_string(x = names(mtcars)[i] , y = "yhat"), data = p1) +
    geom_line(color = '#E51837', size = .6) +
    labs(title = paste("Partial Dependence plot of", names(mtcars)[i] , sep = " ")) +
    theme(text = element_text(color = "#444444", family = 'Helvetica Neue'),
          plot.title = element_text(size = 13, color = '#333333'))
  print(ggplotly(p2, tooltip = c("x", "y")))
}
The plotting loop on my real data set (~22k rows, 30 columns) takes around 2 hours. Any ideas on how to speed it up?
Due to the way data structures are used in R, for() loops can be excruciatingly slow if you're not careful (for example, when a result is grown one element at a time inside the loop). If you want to know more about the technical reasons behind this, take a look at Advanced R by Hadley Wickham.
Practically, there are two main approaches to speeding up what you're looking to do: optimizing the for() loop, and using the apply() family of functions. Both approaches can work well; the apply() version typically performs about as well as a well-written for() loop, keeps the loop body in a single reusable function, and is easy to parallelize, so I'll stick with that solution.
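To make the "if you're not careful" part concrete, here is a small base-R sketch. The `col_summary()` function is a hypothetical stand-in for the real per-column work (the partial() call plus the ggplot), and `datasets::mtcars` is used so the example does not depend on the modified `mtcars` from the question. It compares growing a result inside a for() loop, a preallocated for() loop, and the equivalent lapply() call; all three give the same values, and the lapply() version is the one whose body is already a function of one column name.

```r
# Hypothetical stand-in for the expensive per-column work (partial() + ggplot)
col_summary <- function(x) mean(x) / sd(x)

car_data <- datasets::mtcars
vars <- names(car_data)

# 1. Growing the result inside the loop: c() builds a new, longer vector on
#    every iteration, which gets slow as the number of iterations grows
out_grow <- c()
for (v in vars) {
  out_grow <- c(out_grow, col_summary(car_data[[v]]))
}
names(out_grow) <- vars

# 2. Preallocated for() loop: the vector is created once, then filled in
out_pre <- setNames(numeric(length(vars)), vars)
for (v in vars) {
  out_pre[[v]] <- col_summary(car_data[[v]])
}

# 3. lapply(): the loop body becomes a function of one column name
out_lapply <- unlist(lapply(setNames(vars, vars),
                            function(v) col_summary(car_data[[v]])))

# All three versions produce the same named values
all.equal(out_grow, out_pre)
all.equal(out_pre, out_lapply)
```

With only 11 columns the timing differences are negligible; the point is the shape of the code, since a function of one column name is exactly what lapply() and its parallel variants expect.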
The apply() method:
# Column names to loop over (assumes the setup code from the question has run)
varNames <- names(mtcars)

# Builds one partial dependence plot for column `x`; returns a ggplot object,
# which you can wrap in ggplotly() when you want the interactive version
plotFunction <- function(x) {
  p1 <- pdp::partial(fit,
                     pred.var = x,
                     pred.grid = data.frame(unique(mtcars[x])),
                     train = mtcars.sparse,
                     type = "regression",
                     cats = c("cyl", "vs", "am", "gear", "carb"),
                     plot = FALSE)
  p2 <- ggplot(aes_string(x = x, y = "yhat"), data = p1) +
    geom_line(color = '#E51837', size = .6) +
    labs(title = paste("Partial Dependence plot of", x, sep = " ")) +
    theme(text = element_text(color = "#444444", family = 'Helvetica Neue'),
          plot.title = element_text(size = 13, color = '#333333'))
  return(p2)
}
plot.list <- lapply(varNames, plotFunction)
system.time(lapply(varNames, plotFunction))
user system elapsed
0.471 0.004 0.488
Running the same benchmark on your for() loop gave:
user system elapsed
3.945 0.616 3.519
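As you'll notice, that's roughly a 7x improvement in elapsed time (3.519 s vs. 0.488 s), simply by pasting your loop code into a function with minor modifications. Keep in mind that the function only builds the ggplot objects, while your loop also converted each one with ggplotly() and printed it, so the two timings are not measuring exactly the same work.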
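To see where the time goes on a larger data set, here is a rough base-R sketch of what a partial dependence computation does. It is not pdp's implementation: `partial_dependence()` and its `n_sub` argument are hypothetical, and lm() on the built-in mtcars stands in for the xgboost fit. For every grid value the training data is copied, the column is overwritten, and the model predicts on every row, so each variable costs roughly (number of grid values) times (number of rows) predictions. Since the question passes every unique value as `pred.grid`, that product is likely to dominate the run time on ~22k rows; a shorter grid or a subsample of rows (an assumed tweak, not something from the answer) shrinks it.

```r
# A rough sketch of how partial dependence is computed (base R only; this is
# not pdp's implementation, and partial_dependence() is a hypothetical helper)
partial_dependence <- function(model, train, var, grid = NULL, n_sub = NULL) {
  # Grid of values to evaluate: every unique value, as in the question
  if (is.null(grid)) grid <- sort(unique(train[[var]]))
  # Optional tweak (assumption): average over a random subsample of rows
  if (!is.null(n_sub) && n_sub < nrow(train)) {
    train <- train[sample(nrow(train), n_sub), , drop = FALSE]
  }
  # One full predict() over all (remaining) rows for each grid value
  yhat <- vapply(grid, function(value) {
    newdata <- train
    newdata[[var]] <- value
    mean(predict(model, newdata = newdata))
  }, numeric(1))
  data.frame(x = grid, yhat = yhat)
}

# lm() on the built-in mtcars stands in for the question's xgboost model
car_data <- datasets::mtcars
fit_lm <- lm(mpg ~ ., data = car_data)

pd_wt <- partial_dependence(fit_lm, car_data, "wt")
cat("predictions made:", nrow(pd_wt) * nrow(car_data), "\n")

# Approximation from half the rows: half as many predictions per grid value
set.seed(1)
pd_wt_sub <- partial_dependence(fit_lm, car_data, "wt", n_sub = 16)
```

For a linear model without interactions the partial dependence curve is a straight line whose slope equals the variable's coefficient, which makes the sketch easy to sanity-check; for xgboost the curve can have any shape, but the cost structure is the same.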
If you want additional speed, there are a few tweaks you can make to your function, but perhaps the most powerful aspect of the apply() approach is that it lends itself well to parallelization, which can be done with packages like pbmcapply.
Implementing pbmcapply gives you even more speed:
library(pdp)
library(xgboost)
library(Matrix)
library(ggplot2)
library(plotly)
library(pbmcapply)
# Number of cores to use for parallel processing: I leave one free for the
# rest of the system (detectCores() comes from the base parallel package)
nCores <- max(1, parallel::detectCores() - 1, na.rm = TRUE)
data(mtcars)
target <- mtcars$mpg
mtcars$mpg <- NULL
mtcars.sparse <- sparse.model.matrix(target~., mtcars)
fit <- xgboost(data=mtcars.sparse, label=target, nrounds=100)
# Column names to loop over
varNames <- as.list(names(mtcars))
plotFunction <- function(x) {
  p1 <- pdp::partial(fit,
                     pred.var = x,
                     pred.grid = data.frame(unique(mtcars[x])),
                     train = mtcars.sparse,
                     type = "regression",
                     cats = c("cyl", "vs", "am", "gear", "carb"),
                     plot = FALSE)
  p2 <- ggplot(aes_string(x = x, y = "yhat"), data = p1) +
    geom_line(color = '#E51837', size = .6) +
    labs(title = paste("Partial Dependence plot of", x, sep = " ")) +
    theme(text = element_text(color = "#444444", family = 'Helvetica Neue'),
          plot.title = element_text(size = 13, color = '#333333'))
  return(p2)
}
# Parallel version of lapply() with a progress bar
plot.list <- pbmclapply(varNames, plotFunction, mc.cores = nCores)
Let's see how that did:
user system elapsed
0.842 0.458 0.320
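A small improvement over lapply() (0.320 s vs. 0.488 s elapsed), but the gain should grow with your bigger dataset, where each partial() call takes far longer than the overhead of starting the worker processes. Hope this helps!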
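If you would rather not add a dependency, base R's parallel package provides the same fork-based pattern through mclapply(); pbmclapply() adds a progress bar on top of that approach. A minimal sketch, where `slow_summary()` is a hypothetical stand-in for plotFunction() and `datasets::mtcars` replaces the modified data from the question:

```r
library(parallel)

# Hypothetical stand-in for plotFunction(): some per-column work
slow_summary <- function(v) {
  x <- datasets::mtcars[[v]]
  c(mean = mean(x), sd = sd(x))
}

col_names <- names(datasets::mtcars)

# mclapply() forks the R session; on Windows it only supports mc.cores = 1
n_cores <- if (.Platform$OS.type == "windows") {
  1L
} else {
  max(1L, detectCores() - 1L, na.rm = TRUE)
}

res_serial   <- lapply(col_names, slow_summary)
res_parallel <- mclapply(col_names, slow_summary, mc.cores = n_cores)

# Same results, in the same order as the input
all.equal(res_serial, res_parallel)
```

Because the worker function is the same one you would pass to lapply(), switching between serial and parallel runs is a one-line change, which is handy for debugging before running the full 30 columns.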