Suppose I have a data set and I want to do a 4-fold cross validation using logistic regression. So there will be 4 different models. In R, I did the following:
ctrl <- trainControl(method = "repeatedcv", number = 4, savePredictions = TRUE)
mod_fit <- train(outcome ~., data=data1, method = "glm", family="binomial", trControl = ctrl)
I would assume that mod_fit should contain 4 separate sets of coefficients, but when I type mod_fit$finalModel I just get one set of coefficients.
@Damien your mod_fit will not contain 4 separate sets of coefficients. You are asking for cross-validation with 4 folds; this does not mean you will end up with 4 different models. According to the caret documentation, the train function works as follows:
At the end of the resampling loop (in your case 4 iterations, one per fold), you have one set of averaged performance measures (e.g., RMSE, R-squared) for each candidate set of model parameters.
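A rough base-R sketch of that loop may help. It is not caret's implementation (caret handles fold creation, stratification and metrics itself); it only assumes simulated data, and the names `sim`, `fold_id` and `fold_rmse` are hypothetical. Each fold gets its own fitted model, but only the held-out performance is kept and then averaged.

```r
# Sketch of a k-fold resampling loop in base R (simulated data, hypothetical names).
set.seed(1)
n <- 200
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$outcome <- rbinom(n, 1, plogis(0.5 + 1.2 * sim$x1 - 0.8 * sim$x2))

k <- 4
fold_id <- sample(rep(seq_len(k), length.out = n))  # random fold label for each row

fold_rmse <- numeric(k)
for (i in seq_len(k)) {
  train_rows <- which(fold_id != i)
  test_rows  <- which(fold_id == i)
  # a model is fitted in every iteration, but it is not stored
  fit <- glm(outcome ~ ., data = sim[train_rows, ], family = binomial)
  # predicted probabilities for the held-out fold
  p <- predict(fit, newdata = sim[test_rows, ], type = "response")
  fold_rmse[i] <- sqrt(mean((sim$outcome[test_rows] - p)^2))
}

fold_rmse        # one value per fold
mean(fold_rmse)  # the single averaged figure for this one candidate
```

Only `fold_rmse` survives the loop, which mirrors why the fitted object gives you performance per fold but no per-fold coefficients.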
Since you did not use the tuneGrid or tuneLength argument in the train function, by default train will tune over three values of each tunable parameter.
This means you will have at most three candidate models (not 4 models as you were expecting), and therefore at most three sets of averaged performance measures. For method = "glm" there are no tuning parameters at all, so in your case only one candidate is evaluated.
The optimal candidate is the one with the lowest RMSE (for regression). caret then refits that candidate on the full training set, and the coefficients of this refit are what you see in mod_fit$finalModel.
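To see that the final model is a single refit rather than a combination of the fold models, here is a base-R sketch on simulated data (all names are hypothetical; `final_coefs` only plays the role of mod_fit$finalModel, it is not extracted from caret). It fits one glm per fold and one on all rows, then compares the full-data coefficients with the mean of the fold coefficients.

```r
# Per-fold models vs. one refit on all rows (simulated data, hypothetical names).
set.seed(1)
n <- 200
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$outcome <- rbinom(n, 1, plogis(0.5 + 1.2 * sim$x1 - 0.8 * sim$x2))
fold_id <- sample(rep(1:4, length.out = n))

# coefficients of the 4 per-fold models, each fitted on 3/4 of the rows
fold_coefs <- sapply(1:4, function(i)
  coef(glm(outcome ~ ., data = sim[fold_id != i, ], family = binomial)))

# analogue of the final model: a single fit on every row
final_coefs <- coef(glm(outcome ~ ., data = sim, family = binomial))

fold_coefs                          # 3 x 4 matrix, one column per fold
final_coefs                         # one coefficient vector
rowMeans(fold_coefs) - final_coefs  # generally non-zero: not an average of the folds
```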
I've created a reproducible example based on your code snippet. The first thing to notice about your code is that it specifies repeatedcv as the method but does not set repeats, so it falls back to a single repeat and number = 4 simply gives ordinary 4-fold cross-validation, i.e. 4 resamples (this is not an answer to your question, but it is important to understand).
mod_fit$finalModel gives you only one set of coefficients because it is the single final model: the results from the 4 folds (non-repeated k-fold CV) are used to estimate performance, and the model is then refit once on the full data set.
You can see the fold-level performance in the resample object:
library(caret)
data(iris)
# 1 for setosa, 0 otherwise; kept numeric, so caret treats it as regression (RMSE/R-squared)
iris$binary <- ifelse(iris$Species == "setosa", 1, 0)
iris$Species <- NULL
ctrl <- trainControl(method = "repeatedcv",
                     number = 4,             # 4 folds; repeats defaults to 1
                     savePredictions = TRUE,
                     verboseIter = TRUE,
                     returnResamp = "all")
mod_fit <- train(binary ~ .,
                 data = iris,
                 method = "glm",
                 family = "binomial",
                 trControl = ctrl)
# Fold-level Performance
mod_fit$resample
  RMSE         Rsquared  parameter Resample
1 2.630866e-03 0.9999658 none      Fold1.Rep1
2 3.863821e-08 1.0000000 none      Fold2.Rep1
3 8.162472e-12 1.0000000 none      Fold3.Rep1
4 2.559189e-13 1.0000000 none      Fold4.Rep1
To your earlier point, the package does not save or display the coefficients of each fold. In addition to the performance information above, however, it does save the index (list of in-sample rows), indexOut (held-out rows), and the random seeds for each resample (the extra fifth seed is for the final model fit), so if you were so inclined it would be easy to reconstruct the intermediate models.
mod_fit$control$seeds
[[1]]
[1] 169815

[[2]]
[1] 445763

[[3]]
[1] 871613

[[4]]
[1] 706905

[[5]]
[1] 89408
mod_fit$control$index
$Fold1
 [1] 1 2 3 4 5 6 7 8 9 10 11 12 15 18 19 21 22 24 28 30 31 32 33 34 35 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 59 60 61 63
[45] 64 65 66 68 69 70 71 72 73 75 76 77 79 80 81 82 84 85 86 87 89 90 91 92 93 94 95 96 98 99 100 103 104 106 107 108 110 111 113 114 116 118 119 120
[89] 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 140 141 142 143 145 147 149 150

$Fold2
 [1] 1 6 7 8 12 13 14 15 16 17 18 19 20 21 22 23 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 42 44 46 48 50 51 53 54 55 56 57 58
[45] 59 61 62 64 66 67 69 70 71 72 73 74 75 76 78 79 80 81 82 83 84 85 87 88 89 90 91 92 95 96 97 98 99 101 102 104 105 106 108 109 111 112 113 115
[89] 116 117 119 120 121 122 123 127 130 131 132 134 135 137 138 139 140 141 142 143 144 145 146 147 148

$Fold3
 [1] 2 3 4 5 6 7 8 9 10 11 13 14 16 17 20 23 24 25 26 27 28 29 30 33 35 36 37 38 39 40 41 43 45 46 47 49 50 51 52 54 55 56 57 58
[45] 60 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 82 83 84 85 86 88 89 93 94 97 98 99 100 101 102 103 105 106 107 108 109 110 111 112 114 115
[89] 117 118 119 121 124 125 126 128 129 131 132 133 134 135 136 137 138 139 144 145 146 147 148 149 150

$Fold4
 [1] 1 2 3 4 5 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 29 31 32 34 36 37 38 39 41 42 43 44 45 47 48 49 52 53 55 56
[45] 57 58 59 60 61 62 63 65 67 68 74 77 78 79 80 81 83 86 87 88 90 91 92 93 94 95 96 97 100 101 102 103 104 105 107 109 110 112 113 114 115 116 117 118
[89] 120 122 123 124 125 126 127 128 129 130 133 136 137 138 139 140 141 142 143 144 146 148 149 150
mod_fit$control$indexOut
$Resample1
[1] 13 14 16 17 20 23 25 26 27 29 36 37 38 39 55 56 57 58 62 67 74 78 83 88 97 101 102 105 109 112 115 117 137 138 139 144 146 148

$Resample2
[1] 2 3 4 5 9 10 11 24 41 43 45 47 49 52 60 63 65 68 77 86 93 94 100 103 107 110 114 118 124 125 126 128 129 133 136 149 150

$Resample3
[1] 1 12 15 18 19 21 22 31 32 34 42 44 48 53 59 61 79 80 81 87 90 91 92 95 96 104 113 116 120 122 123 127 130 140 141 142 143

$Resample4
[1] 6 7 8 28 30 33 35 40 46 50 51 54 64 66 69 70 71 72 73 75 76 82 84 85 89 98 99 106 108 111 119 121 131 132 134 135 145 147
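Reconstructing the per-fold models from such an index list takes only a few lines. The helper `refit_folds` below is hypothetical (it is not part of caret); it assumes a named list of in-sample row numbers shaped like mod_fit$control$index and a binomial glm, and the demo builds such a list by hand on simulated data so the block runs without caret.

```r
# Hypothetical helper: one glm per resample, from a list of in-sample row indices.
refit_folds <- function(formula, data, index) {
  lapply(index, function(rows)
    coef(glm(formula, data = data[rows, ], family = binomial)))
}

# Demo: simulated data and a hand-built index list with 4 folds.
set.seed(2)
n <- 120
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$outcome <- rbinom(n, 1, plogis(-0.3 + 1.0 * sim$x1 + 0.6 * sim$x2))

held_out <- split(sample(seq_len(n)), rep(1:4, length.out = n))  # 4 disjoint held-out sets
index <- lapply(held_out, function(out) setdiff(seq_len(n), out))  # in-sample rows per fold
names(index) <- paste0("Fold", 1:4)

coefs <- refit_folds(outcome ~ x1 + x2, sim, index)
do.call(rbind, coefs)  # one row of coefficients per fold
```

With the caret object from above, the call would be `refit_folds(binary ~ ., iris, mod_fit$control$index)`, but because setosa versus the rest is perfectly separable, glm will likely warn about fitted probabilities of 0 or 1 on every fold.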