Jeff W

Reputation: 1

Possible bug in caret function predict.gbm()?

It seems to me that I've discovered a bug in the behavior of the predict() function for method="gbm" in the caret package in R. I'm curious to know whether others agree, or whether someone has an explanation for what this function is doing.

1. Generate data

library(caret)

x1 <- rnorm(100)
x2 <- rnorm(100, 2)
y <- x1 + x2 + rnorm(100)
df <- data.frame(x1=x1, x2=x2, y=y)
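
(A note on reproducibility: the data come from rnorm(), so the exact numbers shown further down depend on the random draw. A minimal sketch of the same step with a fixed seed; the value 1 is an arbitrary choice:)

set.seed(1)  # arbitrary seed; any fixed value makes the rnorm() draws repeatable
x1 <- rnorm(100)
x2 <- rnorm(100, 2)
y <- x1 + x2 + rnorm(100)
df <- data.frame(x1=x1, x2=x2, y=y)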

2. Predict using method="lm"

The following code works as expected: using method="lm", the two predicted values match. In the first case, p1, "y" is included in newdata; in the second case, p2, it is not.

tempd <- df[1:99, c("y", "x1", "x2")]
newdata <- df[100, c("y", "x1", "x2")]

lm.fit <- train(y~x1 + x2, data=tempd, method="lm")
p1 <- predict(lm.fit$finalModel, newdata=newdata)

newdata <- df[100, c("x1", "x2")]
p2 <- predict(lm.fit$finalModel, newdata=newdata)

p1 should equal p2, and does:

p1==p2
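
(A side note: comparing floating-point predictions with == demands bit-for-bit equality. A more forgiving sanity check is all.equal(), with as.numeric() stripping the names that predict() attaches to its result; a small sketch of that alternative:)

# all.equal() tolerates tiny floating-point differences;
# as.numeric() drops the row names so only the values are compared
isTRUE(all.equal(as.numeric(p1), as.numeric(p2)))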

3. Predict using method="gbm"

This code does not work as expected: using method="gbm", with an identical setup, the two predicted values do not match.

tempd <- df[1:99, c("y", "x1", "x2")]
newdata <- df[100, c("y", "x1", "x2")]

gbm.fit <- train(y~x1 + x2, data=tempd, method="gbm", verbose=FALSE)

p1 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

newdata <- df[100, c("x1", "x2")]

p2 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

In this case, p1 does not equal p2:

p1==p2

4. Predict using method="gbm" with a different set up

But, curiously, with one small change (not explicitly naming the variables in the subset operation) it does work:

tempd <- df[1:99, ]
newdata <- df[100, ]

gbm.fit <- train(y~x1 + x2, data=tempd, method="gbm", verbose=FALSE)

p1 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

newdata <- df[100, c("x1", "x2")]

p2 <- predict(gbm.fit$finalModel, newdata=newdata,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)

p1==p2

Thanks in advance for your thoughts.

Jeff

Upvotes: 0

Views: 298

Answers (1)

MrFlick

Reputation: 206167

As @Pascal pointed out, you're skipping an important step. Rather than calling predict() on the finalModel component, you should be calling predict() on the gbm.fit object directly. Note:

class(gbm.fit)
# [1] "train"         "train.formula"
class(gbm.fit$finalModel)
# [1] "gbm"

Since these objects have different classes, they trigger different underlying prediction functions. The important part is that predict.train reshapes newdata into the format the gbm predictor expects. Without this reshaping you will get incorrect results, because the predictor expects the columns to be in a particular order.

Observe

newdata1 <- df[100, c("y","x1","x2")]
newdata2 <- df[100, c("x1","x2")]
newdata3 <- df[100, ]

predict(gbm.fit, newdata1)
# [1] 1.427069
predict(gbm.fit, newdata2)
# [1] 1.427069
predict(gbm.fit, newdata3)
# [1] 1.427069

predict(gbm.fit$finalModel, newdata=newdata1,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 2.166468
predict(gbm.fit$finalModel, newdata=newdata2,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069
predict(gbm.fit$finalModel, newdata=newdata3,
          n.trees=gbm.fit$finalModel$tuneValue$n.trees,                  
          interaction.depth=gbm.fit$finalModel$tuneValue$interaction.depth,
          shrinkage=gbm.fit$finalModel$tuneValue$shrinkage)
# [1] 1.427069
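
You can also see the column-order issue directly if you insist on calling predict on finalModel: subset newdata to exactly the model's predictors first. This is just a sketch; it assumes the fitted gbm object stores its predictor names in var.names, which gbm fits normally do:

vars <- gbm.fit$finalModel$var.names   # the predictors in training order (here "x1", "x2")
predict(gbm.fit$finalModel, newdata=newdata1[, vars, drop=FALSE],
          n.trees=gbm.fit$finalModel$tuneValue$n.trees)
# [1] 1.427069   (now matches predict(gbm.fit, newdata1))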

So if you're going to use the train() function to fit your model, be sure to use the corresponding predict.train method (that is, call predict() on the train object itself) to make predictions from the model correctly.

Upvotes: 2
