Damien

Reputation: 392

The train function in R caret package

Suppose I have a data set and I want to do a 4-fold cross validation using logistic regression. So there will be 4 different models. In R, I did the following:

ctrl <- trainControl(method = "repeatedcv", number = 4, savePredictions = TRUE)
mod_fit <- train(outcome ~., data=data1, method = "glm", family="binomial", trControl = ctrl)

I would assume that mod_fit should contain 4 separate sets of coefficients? When I type mod_fit$finalModel I just get one set of coefficients.

Upvotes: 0

Views: 13493

Answers (2)

Balki

Reputation: 169

@Damien your mod_fit will not contain 4 separate sets of coefficients. You are asking for cross-validation with 4 folds; this does not mean you will get 4 different models. According to the documentation here, the train function works as follows:

(flowchart of the train algorithm, from the caret documentation)

At the end of the resampling loop (in your case 4 iterations, one per fold) you will have one set of averaged forecast-accuracy measures (e.g., RMSE, R-squared) for a given set of model parameters.

Since you did not use the tuneGrid or tuneLength argument in the train function, by default train will tune over three values of each tuneable parameter.

This means you will have at most three candidate models (not the 4 models you were expecting) and therefore at most three sets of averaged model performance measures. (For method = "glm" there are no tuneable parameters at all, so only a single candidate model is evaluated.)

The optimum model is the one with the lowest RMSE (in the case of regression). That model's coefficients are available in mod_fit$finalModel.
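A minimal sketch of this, using the built-in iris data as a stand-in (the original data1 is not available), turning Species into a binary outcome so glm with family = "binomial" applies:

```r
library(caret)

data(iris)
iris$binary  <- ifelse(iris$Species == "setosa", 1, 0)
iris$Species <- NULL

ctrl    <- trainControl(method = "cv", number = 4)
mod_fit <- train(binary ~ ., data = iris, method = "glm",
                 family = "binomial", trControl = ctrl)

mod_fit$results           # one row of resampling-averaged performance (glm tunes nothing)
coef(mod_fit$finalModel)  # the single set of coefficients that train keeps
```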

Upvotes: 1

Hack-R

Reputation: 23200

I've created a reproducible example based on your code snippet. The first thing to notice about your code is that it specifies repeatedcv as the method but does not supply any repeats, so the number = 4 parameter is just telling it to resample 4 times (this is not an answer to your question, but it is important to understand).
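For comparison, repeatedcv normally goes with an explicit repeats argument, while plain 4-fold CV (what your question actually describes) just uses method = "cv":

```r
library(caret)

# 4-fold CV repeated 5 times: 20 resampling iterations in total
ctrl_rep <- trainControl(method = "repeatedcv", number = 4, repeats = 5)

# Plain, non-repeated 4-fold CV
ctrl_cv  <- trainControl(method = "cv", number = 4)
```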

mod_fit$finalModel gives you only 1 set of coefficients because it is the single final model that caret fits on the full training data once the non-repeated k-fold CV across the 4 folds has been used to estimate performance.

You can see the fold-level performance in the resample object:

library(caret)

# iris ships with base R, so no extra data package is needed
data(iris)

iris$binary  <- ifelse(iris$Species=="setosa",1,0)
iris$Species <- NULL

ctrl    <- trainControl(method = "repeatedcv", 
                        number = 4, 
                        savePredictions = TRUE,
                        verboseIter = TRUE,
                        returnResamp = "all")

mod_fit <- train(binary ~., 
                 data=iris, 
                 method = "glm", 
                 family="binomial", 
                 trControl = ctrl)


# Fold-level Performance
mod_fit$resample
          RMSE  Rsquared parameter   Resample
1 2.630866e-03 0.9999658      none Fold1.Rep1
2 3.863821e-08 1.0000000      none Fold2.Rep1
3 8.162472e-12 1.0000000      none Fold3.Rep1
4 2.559189e-13 1.0000000      none Fold4.Rep1
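Because savePredictions = TRUE was set, the held-out predictions for every fold are also kept. Continuing from the snippet above:

```r
# Held-out predictions, one row per held-out observation per fold
head(mod_fit$pred)            # columns include pred, obs, rowIndex, Resample
table(mod_fit$pred$Resample)  # how many rows were held out in each fold
```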

To your earlier point, the package is not going to save and display the coefficients of each fold. In addition to the performance information above, however, it does save the index (the in-sample rows for each fold), the indexOut (the held-out rows), and the random seeds for each fold, so if you were so inclined it would be easy to reconstruct the intermediate models.

mod_fit$control$seeds
[[1]]
[1] 169815

[[2]]
[1] 445763

[[3]]
[1] 871613

[[4]]
[1] 706905

[[5]]
[1] 89408
mod_fit$control$index
$Fold1
  [1]   1   2   3   4   5   6   7   8   9  10  11  12  15  18  19  21  22  24  28  30  31  32  33  34  35
 [26]  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  59  60  61  63  64  65  66  68  69  70
 [51]  71  72  73  75  76  77  79  80  81  82  84  85  86  87  89  90  91  92  93  94  95  96  98  99 100
 [76] 103 104 106 107 108 110 111 113 114 116 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
[101] 133 134 135 136 140 141 142 143 145 147 149 150

$Fold2
  [1]   1   6   7   8  12  13  14  15  16  17  18  19  20  21  22  23  25  26  27  28  29  30  31  32  33
 [26]  34  35  36  37  38  39  40  42  44  46  48  50  51  53  54  55  56  57  58  59  61  62  64  66  67
 [51]  69  70  71  72  73  74  75  76  78  79  80  81  82  83  84  85  87  88  89  90  91  92  95  96  97
 [76]  98  99 101 102 104 105 106 108 109 111 112 113 115 116 117 119 120 121 122 123 127 130 131 132 134
[101] 135 137 138 139 140 141 142 143 144 145 146 147 148

$Fold3
  [1]   2   3   4   5   6   7   8   9  10  11  13  14  16  17  20  23  24  25  26  27  28  29  30  33  35
 [26]  36  37  38  39  40  41  43  45  46  47  49  50  51  52  54  55  56  57  58  60  62  63  64  65  66
 [51]  67  68  69  70  71  72  73  74  75  76  77  78  82  83  84  85  86  88  89  93  94  97  98  99 100
 [76] 101 102 103 105 106 107 108 109 110 111 112 114 115 117 118 119 121 124 125 126 128 129 131 132 133
[101] 134 135 136 137 138 139 144 145 146 147 148 149 150

$Fold4
  [1]   1   2   3   4   5   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  29
 [26]  31  32  34  36  37  38  39  41  42  43  44  45  47  48  49  52  53  55  56  57  58  59  60  61  62
 [51]  63  65  67  68  74  77  78  79  80  81  83  86  87  88  90  91  92  93  94  95  96  97 100 101 102
 [76] 103 104 105 107 109 110 112 113 114 115 116 117 118 120 122 123 124 125 126 127 128 129 130 133 136
[101] 137 138 139 140 141 142 143 144 146 148 149 150

mod_fit$control$indexOut
$Resample1
 [1]  13  14  16  17  20  23  25  26  27  29  36  37  38  39  55  56  57  58  62  67  74  78  83  88  97 101 102 105 109 112 115 117 137
[34] 138 139 144 146 148

$Resample2
 [1]   2   3   4   5   9  10  11  24  41  43  45  47  49  52  60  63  65  68  77  86  93  94 100 103 107 110 114 118 124 125 126 128 129
[34] 133 136 149 150

$Resample3
 [1]   1  12  15  18  19  21  22  31  32  34  42  44  48  53  59  61  79  80  81  87  90  91  92  95  96 104 113 116 120 122 123 127 130
[34] 140 141 142 143

$Resample4
 [1]   6   7   8  28  30  33  35  40  46  50  51  54  64  66  69  70  71  72  73  75  76  82  84  85  89  98  99 106 108 111 119 121 131
[34] 132 134 135 145 147
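For example, reconstructing all 4 intermediate models by hand from the saved in-sample indices (a sketch reusing the iris and mod_fit objects defined above; refit_one is just an illustrative name):

```r
# Refit the model for each fold from the saved in-sample row indices
refit_one <- function(rows) {
  glm(binary ~ ., data = iris[rows, ], family = "binomial")
}
fold_models <- lapply(mod_fit$control$index, refit_one)

# The 4 separate sets of coefficients you were originally expecting
lapply(fold_models, coef)
```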

Upvotes: 5
