It seems to me that I've discovered a bug in the behavior of the predict() function for method="gbm" in the caret package in R. I'm curious to know whether others agree, or whether someone has an explanation for the behavior of this function.
1. Generate data
library(caret)
set.seed(1)  # any seed; makes the simulated data reproducible
x1 <- rnorm(100)
x2 <- rnorm(100, 2)
y <- x1 + x2 + rnorm(100)
df <- data.frame(x1=x1, x2=x2, y=y)
2. Predict using method="lm"
The following code works as expected: with method="lm" the two predicted values match. In the first case (p1), "y" is included in newdata; in the second case (p2), it is not.
tempd <- df[1:99, c("y", "x1", "x2") ]
newdata <- df[100, c("y", "x1", "x2")]
lm.fit <- train(y~x1 + x2, data=tempd, method="lm")
p1 <- predict(lm.fit$finalModel, newdata=newdata)
newdata <- df[100, c("x1", "x2")]
p2 <- predict(lm.fit$finalModel, newdata=newdata)
p1 should equal p2, and does:
p1==p2
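For comparison, here is a minimal base R sketch (no caret; the seed is arbitrary) of why the lm case is robust. predict.lm rebuilds the model matrix from the fitted model's terms, so it looks up x1 and x2 by name and ignores an extra y column or a different column order. The sketch calls stats::lm() directly, which, as far as I can tell, is also what caret's method="lm" fits internally.

```r
# Simulated data in the same shape as above; the seed is arbitrary.
set.seed(42)
df <- data.frame(x1 = rnorm(100), x2 = rnorm(100, 2))
df$y <- df$x1 + df$x2 + rnorm(100)

# Plain lm() fit on the first 99 rows.
fit <- lm(y ~ x1 + x2, data = df[1:99, ])

# predict.lm matches predictors by name through the model terms,
# so an extra "y" column or a shuffled column order changes nothing.
p_with_y   <- predict(fit, newdata = df[100, c("y", "x1", "x2")])
p_without  <- predict(fit, newdata = df[100, c("x1", "x2")])
p_shuffled <- predict(fit, newdata = df[100, c("x2", "y", "x1")])

all.equal(p_with_y, p_without)    # TRUE
all.equal(p_without, p_shuffled)  # TRUE
```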
3. Predict using method="gbm"
This code does not work as expected: with method="gbm" and an otherwise identical setup, the two predicted values do not match.
tempd <- df[1:99, c("y","x1","x2")]
newdata <- df[100, c("y","x1","x2")]
gbm.fit <- train(y~x1+x2, data=tempd, method="gbm", verbose=FALSE)
p1 <- predict(gbm.fit$finalModel, newdata=newdata,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
newdata <- df[100, c("x1","x2")]
p2 <- predict(gbm.fit$finalModel, newdata=newdata,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
In this case, p1 does not equal p2:
p1==p2
4. Predict using method="gbm" with a different setup
BUT, curiously, with one small change (not explicitly naming the variables in the subset operation), it does work:
tempd <- df[1:99, ]
newdata <- df[100, ]
gbm.fit <- train(y~x1+x2, data=tempd, method="gbm", verbose=FALSE)
p1 <- predict(gbm.fit$finalModel, newdata=newdata,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
newdata <- df[100, c("x1","x2")]
p2 <- predict(gbm.fit$finalModel, newdata=newdata,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
p1==p2
Thanks in advance for your thoughts.
Jeff
As @Pascal pointed out, you're skipping an important step. Rather than calling predict() on the finalModel element, you should be calling predict() on the gbm.fit object directly. Note:
class(gbm.fit)
# [1] "train" "train.formula"
class(gbm.fit$finalModel)
# [1] "gbm"
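Since these objects have different classes, they dispatch to different underlying prediction methods: predict.train for gbm.fit and predict.gbm for gbm.fit$finalModel. The important part is that predict.train reshapes newdata into the format the gbm predictor expects. Without this reshaping you will get incorrect results, because the gbm predictor expects the columns to be in a certain order. In newdata1 below, y sits in the first column, where x1 should be. newdata3 works only because df happens to store its columns as x1, x2, y, which is also why section 4 of the question appeared to work.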
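To make the dispatch point concrete without caret or gbm, here is a toy base R sketch. The classes "toy_train" and "toy_inner" and their predict methods are hypothetical stand-ins, not caret's or gbm's actual code: the wrapper's method selects the predictor columns by name before delegating, while the inner method reads columns by position, which is the behavior described above for the gbm predictor.

```r
# Inner "model": a bare coefficient vector applied by column POSITION.
predict.toy_inner <- function(object, newdata, ...) {
  x <- as.matrix(newdata[, seq_along(object$coefs), drop = FALSE])
  drop(x %*% object$coefs)
}

# Wrapper "model": selects predictor columns by NAME, then dispatches
# to the inner model's predict method.
predict.toy_train <- function(object, newdata, ...) {
  predict(object$finalModel, newdata[, object$var_names, drop = FALSE], ...)
}

inner   <- structure(list(coefs = c(1, 1)), class = "toy_inner")
wrapper <- structure(list(finalModel = inner, var_names = c("x1", "x2")),
                     class = "toy_train")

nd1 <- data.frame(y = 10, x1 = 1, x2 = 2)  # y first, like newdata1
nd2 <- data.frame(x1 = 1, x2 = 2)          # like newdata2
nd3 <- data.frame(x1 = 1, x2 = 2, y = 10)  # like newdata3

via_wrapper <- sapply(list(nd1, nd2, nd3), function(nd) predict(wrapper, nd))
via_inner   <- sapply(list(nd1, nd2, nd3), function(nd) predict(wrapper$finalModel, nd))
via_wrapper  # 3 3 3
via_inner    # 11 3 3 (for nd1, y and x1 were used in place of x1 and x2)
```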
Observe
newdata1 <- df[100, c("y","x1","x2")]
newdata2 <- df[100, c("x1","x2")]
newdata3 <- df[100, ]
predict(gbm.fit, newdata1)
# [1] 1.427069
predict(gbm.fit, newdata2)
# [1] 1.427069
predict(gbm.fit, newdata3)
# [1] 1.427069
predict(gbm.fit$finalModel, newdata=newdata1,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 2.166468
predict(gbm.fit$finalModel, newdata=newdata2,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069
predict(gbm.fit$finalModel, newdata=newdata3,
n.trees=gbm.fit$finalModel$tuneValue$n.trees,
interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069
So if you're going to use the train() function to fit your model, be sure to call predict() on the train object itself, so that predict.train is used and the model's predictions are correct.
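If you do need to call the underlying model directly, one workaround is to select the predictor columns by name, in the model's order, before calling it. Below is a minimal base R sketch; align_predictors() is a hypothetical helper. For a caret gbm model the names could presumably come from caret's predictors(gbm.fit) or the var.names element of the gbm object, but treat that as an assumption to check against your package versions.

```r
# Hypothetical helper: keep only the predictor columns, in the order given.
align_predictors <- function(newdata, var_names) {
  absent <- setdiff(var_names, names(newdata))
  if (length(absent) > 0) {
    stop("newdata is missing predictor(s): ", paste(absent, collapse = ", "))
  }
  newdata[, var_names, drop = FALSE]
}

nd <- data.frame(y = 10, x1 = 1, x2 = 2)  # response column first, as in newdata1
aligned <- align_predictors(nd, c("x1", "x2"))
names(aligned)  # "x1" "x2"
```

Failing loudly on a missing predictor is deliberate: silently passing a short data frame to a positional predictor would reproduce the same kind of wrong answer.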