I am trying to learn the tidymodels ecosystem by converting caret::train() code into tidymodels workflows. I am getting differences that I think are a byproduct of the resampling algorithms in caret vs. rsample.
. A colleague wrote a gist showing the differences in datasets with the same seed: https://gist.github.com/bradleyboehmke/7794b79a07afb443da11d930ff84bed7
You can see small differences here in simple models (that I think I coded to be the same):
library(caret)
library(tidyverse)
library(tidymodels)
data(ames)
set.seed(123)
(cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = trainControl(method = "cv", number = 10)
))
vs.
set.seed(123)
folds <- vfold_cv(ames, v = 10)
the_lm_model <-
  linear_reg() %>%
  set_engine("lm")
the_rec <-
  recipe(Sale_Price ~ Gr_Liv_Area, data = ames)
the_workflow <-
  workflow() %>%
  add_recipe(the_rec) %>%
  add_model(the_lm_model)
the_results <-
  fit_resamples(the_workflow, folds)
collect_metrics(the_results)
Is there a straightforward way to use caret resamples (from caret::createFolds()) in a tidymodels workflow (where they would normally be created with rsample::vfold_cv())? I am hoping that if I can figure out this detail, I can replicate complex old code with the new ecosystem (for teaching).
Upvotes: 1
Views: 246
Reputation: 19716
Edit: thanks to Julia Silge's comment.
The rsample functions rsample2caret() and caret2rsample() can be used to convert resampling objects between the two formats.
The answer below can still be useful for converting from arbitrary formats to rsample.
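As a rough sketch of how these two converters might be used (not run against the exact package versions in the question; the object names rs_folds, caret_idx, ctrl, fit and folds_from_caret are made up for illustration, and the ames data is assumed to come from the modeldata package):

```r
library(caret)
library(rsample)
data(ames, package = "modeldata")

# rsample -> caret: build the folds with rsample, then hand the row indices to caret
set.seed(123)
rs_folds <- vfold_cv(ames, v = 10)
caret_idx <- rsample2caret(rs_folds)  # list with elements $index and $indexOut
ctrl <- trainControl(
  method = "cv",
  index = caret_idx$index,
  indexOut = caret_idx$indexOut
)

# caret -> rsample: reuse the folds stored in a fitted train object
set.seed(123)
fit <- train(
  Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = trainControl(method = "cv", number = 10)
)
folds_from_caret <- caret2rsample(fit$control, data = ames)
```

The ctrl object can then be passed as trControl to another train() call, and folds_from_caret should be usable wherever an rsample resampling object is expected, such as fit_resamples().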
Old Answer
Here is an approach to convert the output of caret::createFolds() to rsample:
library(caret)
library(tidyverse)
library(tidymodels)
data(ames)
#create train folds
set.seed(123)
folds_train <- caret::createFolds(ames$Sale_Price, returnTrain = TRUE, k = 10)
#get test indexes
folds_test <- lapply(folds_train, function(x) setdiff(seq_along(ames$Sale_Price), x))
Combine the train and test indices into a list of analysis/assessment lists, as described in the manual_rset() documentation:
rsplit <- map2(folds_train,
               folds_test,
               function(x, y) list(analysis = x, assessment = y))
splits <- lapply(rsplit, make_splits, data = ames)
splits <- manual_rset(splits, names(splits))
> splits
# Manual resampling
# A tibble: 10 x 2
splits id
<named list> <chr>
1 <split [2637/293]> Fold01
2 <split [2638/292]> Fold02
3 <split [2637/293]> Fold03
4 <split [2637/293]> Fold04
5 <split [2638/292]> Fold05
6 <split [2637/293]> Fold06
7 <split [2637/293]> Fold07
8 <split [2636/294]> Fold08
9 <split [2636/294]> Fold09
10 <split [2637/293]> Fold10
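Before fitting anything, it may be worth confirming that the converted object really carries caret's folds. The following is a minimal, self-contained sketch of such a check in base R plus rsample; the names rset, same_train, held_out and covers_all_rows are illustrative and not part of the original answer:

```r
library(caret)
library(rsample)
data(ames, package = "modeldata")

# Recreate the caret folds and the manual rsample object as above
set.seed(123)
folds_train <- createFolds(ames$Sale_Price, k = 10, returnTrain = TRUE)
folds_test <- lapply(folds_train, function(x) setdiff(seq_len(nrow(ames)), x))
rsplit <- Map(function(x, y) list(analysis = x, assessment = y),
              folds_train, folds_test)
rset <- manual_rset(lapply(rsplit, make_splits, data = ames), names(rsplit))

# 1. The analysis rows of each split should be caret's training rows
same_train <- mapply(
  function(s, idx) setequal(as.integer(s, data = "analysis"), idx),
  rset$splits, folds_train
)

# 2. Every row should be held out exactly once across the 10 assessment sets
held_out <- unlist(
  lapply(rset$splits, function(s) as.integer(s, data = "assessment")),
  use.names = FALSE
)
covers_all_rows <- identical(sort(held_out), seq_len(nrow(ames)))

all(same_train)
covers_all_rows
```

If both checks come back TRUE, any remaining difference in metrics between caret and tidymodels would have to come from something other than the fold assignment.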
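Check whether caret gives the same result with these folds (the printout below says "Bootstrapped" only because trainControl() defaults to method = "boot"; the supplied index still defines the 10 folds):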
set.seed(123)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = trainControl(index = folds_train))
> cv_model1
Linear Regression
2930 samples
1 predictor
No pre-processing
Resampling: Bootstrapped (10 reps)
Summary of sample sizes: 2637, 2638, 2637, 2637, 2638, 2637, ...
Resampling results:
RMSE Rsquared MAE
56364.67 0.5066935 38575.21
Tuning parameter 'intercept' was held constant at a value of TRUE
the_lm_model <-
  linear_reg() %>%
  set_engine("lm")
the_rec <-
  recipe(Sale_Price ~ Gr_Liv_Area, data = ames)
the_workflow <-
  workflow() %>%
  add_recipe(the_rec) %>%
  add_model(the_lm_model)
set.seed(123)
the_results <-
  fit_resamples(the_workflow, splits)
collect_metrics(the_results)
> collect_metrics(the_results)
# A tibble: 2 x 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 rmse standard 56365. 10 1782. Preprocessor1_Model1
2 rsq standard 0.507 10 0.0220 Preprocessor1_Model1
all.equal(
  cv_model1$results$RMSE,
  collect_metrics(the_results)$mean[1])
TRUE
Perhaps there is a more straightforward way, but I don't use tidymodels enough to know for sure.
If you did not create folds prior to calling caret::train():
set.seed(123)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = trainControl(number = 10, method = "cv"))
you can use cv_model1$control$index and cv_model1$control$indexOut to create an rsample object:
rsplit <- map2(cv_model1$control$index,
               cv_model1$control$indexOut,
               function(x, y) list(analysis = x, assessment = y))
and proceed as described above:
splits <- lapply(rsplit, make_splits, data = ames)
splits <- manual_rset(splits, names(splits))