Amelio Vazquez-Reina

Reputation: 96448

Statistics of prediction for multiple models with caret

I am trying to get prediction statistics for several models trained with the caret package. Below is an example that illustrates what I need:

library(caret)

# Training:
# ... Get X and Y for training a binary classification problem.
# ... X is the input (2000 x 5), Y is the output (2000 x 1), a factor whose
# ... levels are valid R names (needed because classProbs = TRUE) ...

tmp <- createDataPartition(Y, p = 3/4, times = 3, list = TRUE, groups = min(5, length(Y)))

myCtrl <- trainControl(method = "boot", index = tmp, timingSamps = 2, classProbs = TRUE, summaryFunction = twoClassSummary)

RFmodel <- train(X,Y,method='rf',trControl=myCtrl,tuneLength=1, metric="ROC")
SVMmodel <- train(X,Y,method='svmRadial',trControl=myCtrl,tuneLength=3, metric="ROC")
KNNmodel <- train(X,Y,method='knn',trControl=myCtrl,tuneLength=10, metric="ROC")
NNmodel <- train(X,Y,method='nnet',trControl=myCtrl,tuneLength=3, trace = FALSE, metric="ROC")

# resamps reports ROC, Sens, Spec for all models
resamps <- resamples(list(RF = RFmodel, KNN = KNNmodel, NN = NNmodel, SVM = SVMmodel))

# Prediction:
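# ... Collect X_pred (7000 x 5) and Y_pred (7000 x 1) ...
# predict() on a list of train objects returns one data frame of class probabilities per model
testPred <- predict(list(RF = RFmodel, KNN = KNNmodel, NN = NNmodel, SVM = SVMmodel), newdata = X_pred, type = "prob")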

How can I get the prediction statistics (ROC, etc.) from X_pred and Y_pred for my 4 models?

Upvotes: 0

Views: 1667

Answers (1)

Jakub Langr

Reputation: 626

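# Make a list of all the models
all.models <- list(model1, model2, model3, model4, model5, model6)
# Name each element after its caret method, e.g. "rf" or "knn"
names(all.models) <- sapply(all.models, function(x) x$method)
# Rank the models by their best resampled RMSE
sort(sapply(all.models, function(x) min(x$results$RMSE)))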

If I recall correctly, the code above is not mine.

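# Table

# CORRELATIONS between each model's predictions on TD and the outcome Y
# (sapply covers all six models and closes the call that was left open)
correlations <- sapply(all.models, function(m) cor(predict(m, newdata = TD), Y))

# Resampled RMSE of each model's best tuning parameters (computed during training)
RMSE <- as.numeric(sapply(all.models, function(x) min(x$results$RMSE)))

# Descriptive labels, in the same order as model1 ... model6
model.names <- c('General Linear Model', 'Random Forests', 'Artificial Neural Networks', 'Logistic/multinomial regression', 'K nearest neighbors', 'Support Vector Machines')

# A data.frame keeps the numeric columns numeric; matrix(c(...)) would coerce everything to text
data.frame(Model = model.names, Correlation = as.numeric(correlations), RMSE = RMSE)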
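The RMSE in that table comes from resampling during training, not from new data. Here is a minimal base-R sketch that scores held-out predictions instead. It assumes you already have a named list of numeric prediction vectors (one per model) and the matching observed outcomes; `holdout_stats` is a hypothetical helper name, not a caret function.

```r
# Hypothetical helper (not part of caret): summarise held-out predictions.
# `preds` is a named list of numeric prediction vectors, one per model;
# `obs` is the observed outcome for the same held-out rows.
holdout_stats <- function(preds, obs) {
  stats <- t(sapply(preds, function(p) {
    c(Correlation = cor(p, obs),
      RMSE        = sqrt(mean((p - obs)^2)),
      MAE         = mean(abs(p - obs)))
  }))
  out <- data.frame(Model = rownames(stats), stats, row.names = NULL)
  out[order(out$RMSE), ]  # best (lowest RMSE) model first
}

# Toy example with made-up numbers
obs   <- c(1, 2, 3, 4)
preds <- list(good = c(1.1, 1.9, 3.2, 3.8), bad = c(4, 3, 2, 1))
holdout_stats(preds, obs)
```

With caret models, `preds <- lapply(all.models, predict, newdata = TD)` followed by `holdout_stats(preds, Y)` would fill the table, assuming `TD` holds held-out predictors and `Y` their outcomes.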

Hope this helps. I know it's not ROC, but these are still useful statistics for comparing predictions.
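For the classification setup in the question, the same idea works with class probabilities. Below is a minimal base-R sketch, assuming `probs` is the named list returned by `predict(..., type = "prob")` (one data frame per model, one column per class) and `obs` is a factor of the true test classes. Like caret's `twoClassSummary`, it treats the first factor level as the event. The helper name `two_class_stats` and the 0.5 cutoff are assumptions for illustration.

```r
# Hypothetical helper (not part of caret): ROC AUC, sensitivity and
# specificity for several models on held-out data.
two_class_stats <- function(probs, obs, cutoff = 0.5) {
  event    <- levels(obs)[1]     # first level = event, as in twoClassSummary
  is_event <- obs == event
  n1 <- sum(is_event)
  n0 <- sum(!is_event)
  stats <- t(sapply(probs, function(p) {
    score <- p[[event]]          # predicted probability of the event class
    # AUC via the Mann-Whitney rank statistic (ties get average ranks)
    r   <- rank(score)
    auc <- (sum(r[is_event]) - n1 * (n1 + 1) / 2) / (n1 * n0)
    pred_event <- score >= cutoff
    c(ROC  = auc,
      Sens = mean(pred_event[is_event]),    # true positive rate
      Spec = mean(!pred_event[!is_event]))  # true negative rate
  }))
  data.frame(Model = rownames(stats), stats, row.names = NULL)
}

# Toy example with made-up probabilities
obs   <- factor(c("yes", "yes", "no", "no"), levels = c("yes", "no"))
probs <- list(
  A = data.frame(yes = c(0.9, 0.6, 0.4, 0.2), no = c(0.1, 0.4, 0.6, 0.8)),
  B = data.frame(yes = c(0.7, 0.3, 0.6, 0.1), no = c(0.3, 0.7, 0.4, 0.9))
)
two_class_stats(probs, obs)
```

For the question's models this would be `probs <- predict(list(RF = RFmodel, KNN = KNNmodel, NN = NNmodel, SVM = SVMmodel), newdata = X_pred, type = "prob")` followed by `two_class_stats(probs, Y_pred)`, assuming `Y_pred` is a factor with the same levels as `Y`. The rank-based AUC equals the area under the empirical ROC curve, so no extra package is needed for it.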

Upvotes: 1
