Nils Kappelmann

Reputation: 21

R: caret package predict.train leading to nonsensical predictions

I'm running into a problem with the caret package's predict.train function, which gives seemingly random predictions for my "positive control" models (i.e., models that are supposed to predict relatively well). The problem persists across different algorithms ("glmnet" and "rf"). Yet when I compute predictions manually from the final glmnet model's coefficients, the predictions are good.

Here is an example of what I did:

Model Setup:

## Define fitControl object
fitControl = trainControl(method = "cv",
                          number = 5)
   
## Define tuneGrid
glmnet.tuneGrid = expand.grid(alpha = seq(from = 0, to = 1, by = 0.2),
                              lambda = seq(from = 0, to = 1, by = 0.2))

## Run inner CV
glmnet.fit = train(x = train[,x], y = train[,y], 
                   method = "glmnet", metric = "RMSE", 
                   trControl = fitControl,
                   tuneGrid = glmnet.tuneGrid)

Extraction of predictions using predict.train function and manual computation using best model coefficients:

## Predict in test set
glmnet.preds = predict(glmnet.fit, newdata = test)
            
## Compute manual predictions
glmnet.coefs = coef(glmnet.fit$finalModel, s = glmnet.fit$bestTune$lambda)
manual.preds = as.vector(
glmnet.coefs[1,] + 
glmnet.coefs[2,]*test$t0_bdi_std + 
glmnet.coefs[3,]*test$sex_std + 
glmnet.coefs[4,]*test$age_std + 
glmnet.coefs[5,]*test$BMI_std)

If I evaluate the predictions, I get different values:

> glmnet.preds
       3        6        7       17       20       23       27       37       38       47       54 
21.07649 18.32825 18.30302 19.02607 21.18579 21.91725 18.84951 21.46324 18.64773 21.30349 22.01814 
      56       66       67       69       74       77       88       89       92       98      104 
21.52209 21.44642 18.65614 21.18579 19.54734 19.67345 21.86680 20.96719 18.79066 21.03445 20.81586 
     108 
19.19422 
> manual.preds
 [1] 20.97291 17.80435 17.77912 15.64083 13.21352 17.52165 20.47162 18.85598 22.05817 15.11957 21.91455
[12] 18.55717 16.69316 17.05924 25.01654 27.60745 22.36856 18.54421 16.21393 19.69743 18.06953 18.56627
[23] 23.32000

I also obtain different fit indices: those for the predict.train output look essentially random, while the manual predictions perform well, as expected:

> postResample(test[, y], glmnet.preds)
        RMSE     Rsquared          MAE 
13.665491040  0.004892648 11.756136481 
> postResample(test[, y], manual.preds)
      RMSE   Rsquared        MAE 
11.7743854  0.4606725 10.0398907 
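For reference, the three indices can also be computed by hand. The sketch below uses made-up numbers and assumes that caret's default Rsquared is the squared Pearson correlation between predictions and observations. postResample documents its arguments as pred first and obs second, but all three measures are symmetric in their arguments, so the order used above does not change the numbers.

```r
# Made-up observed and predicted values, purely for illustration.
obs  <- c(10, 20, 30, 40)
pred <- c(12, 18, 33, 37)

rmse <- sqrt(mean((pred - obs)^2))  # root mean squared error
mae  <- mean(abs(pred - obs))       # mean absolute error
rsq  <- cor(pred, obs)^2            # squared Pearson correlation (assumed definition)

fit_indices <- c(RMSE = rmse, Rsquared = rsq, MAE = mae)
```

Read this way, an Rsquared near zero alongside a similar RMSE suggests predictions on the right scale that are almost uncorrelated with the outcome.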

An interesting side note: I tried to create a fully reproducible example with simulated data, but with those data predict.train gave the same results as the manual computation.

I'd be really curious (and immensely grateful) if anyone knows what leads to these results and how to fix it.

System info:

Upvotes: 2

Views: 816

Answers (1)

StupidWolf

Reputation: 46978

Since the data are not provided, there is no way to know whether your manual calculation is correct. Below I use an example dataset:

library(mlbench)
library(caret)

data(BostonHousing)
#exclude one factor column
tr_dat = BostonHousing[1:300,-4]
test_dat = BostonHousing[301:nrow(BostonHousing),-4]

fitControl = trainControl(method = "cv",number = 5)
   
glmnet.tuneGrid = expand.grid(alpha = seq(from = 0, to = 1, by = 0.2),
                              lambda = seq(from = 0, to = 1, by = 0.2))

glmnet.fit = train(x = tr_dat[,-ncol(tr_dat)], y = tr_dat[,ncol(tr_dat)], 
                   method = "glmnet", metric = "RMSE",
                   trControl = fitControl, tuneGrid = glmnet.tuneGrid)

Caret prediction:

pred_caret = predict(glmnet.fit,newdata=test_dat)

For the manual prediction, you can do a matrix multiplication (%*%) between the predictor matrix and your coefficients:

predictor = cbind(Intercept=1,as.matrix(test_dat[,-ncol(test_dat)]))
coef_m = as.matrix(coef(glmnet.fit$finalModel,s=glmnet.fit$bestTune$lambda))
pred_manual = predictor %*% coef_m

table(pred_manual == pred_caret)

TRUE 
 206 

You get back exactly the same predictions.
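One thing worth checking when a hand computation and predict() disagree is whether the columns of the new data are in the same order as the coefficient rows, since a positional product silently pairs the wrong values otherwise. Below is a minimal base-R sketch of aligning by name before multiplying. The coefficient values, the column names a, b, c and the helper predict_by_name are all made up for illustration; with a real fit, coef_m would come from coef(glmnet.fit$finalModel, s = ...) as above. Whether this is what went wrong in the question cannot be determined without the data.

```r
# Hypothetical coefficient matrix shaped like as.matrix(coef(<glmnet fit>, s = ...)):
# one column, row names "(Intercept)" followed by the predictor names.
coef_m <- matrix(c(1.0, 2.0, -3.0, 0.5), ncol = 1,
                 dimnames = list(c("(Intercept)", "a", "b", "c"), "s1"))

# Made-up test data whose columns are NOT in the coefficient order,
# plus an extra column (the outcome) that must not enter the product.
test_df <- data.frame(c = c(2, 4), y = c(0, 0), a = c(1, 0), b = c(0, 1))

# Hypothetical helper: select and order the predictors by coefficient name.
predict_by_name <- function(coef_m, newdata) {
  vars <- setdiff(rownames(coef_m), "(Intercept)")
  missing_vars <- setdiff(vars, names(newdata))
  if (length(missing_vars) > 0) {
    stop("newdata lacks: ", paste(missing_vars, collapse = ", "))
  }
  X <- cbind("(Intercept)" = 1, as.matrix(newdata[, vars, drop = FALSE]))
  drop(X %*% coef_m[colnames(X), , drop = FALSE])
}

pred_named <- predict_by_name(coef_m, test_df)

# Positional product on the first three columns, for contrast.
pred_positional <- drop(cbind(1, as.matrix(test_df[, 1:3])) %*% coef_m)

# Compare with a tolerance rather than ==, since floating-point results
# from different code paths may differ in the last bits.
same <- isTRUE(all.equal(unname(pred_named), unname(pred_positional)))
```

In this toy case the two products differ, which is the kind of mismatch a name-based lookup guards against.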

Upvotes: 1
