itsMeInMiami

Reputation: 2659

How can I use the same cross-validation sets in R caret and rsample?

I am trying to learn the tidymodels ecosystem by converting caret::train() code into tidymodels workflows. I am getting differences that I think are a byproduct of the different resampling algorithms in caret and rsample. A colleague wrote a gist showing that the two packages produce different resamples from the same seed: https://gist.github.com/bradleyboehmke/7794b79a07afb443da11d930ff84bed7

You can see small differences here in simple models (that I think I coded to be the same):

library(caret)
library(tidyverse)
library(tidymodels)
data(ames)

set.seed(123)
(cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area, 
  data = ames,
  method = "lm",
  trControl = trainControl(method="cv", number = 10)
))

vs.

set.seed(123)
folds <- vfold_cv(ames, v = 10)

the_lm_model <- 
  linear_reg() %>% 
  set_engine("lm")

the_rec <- 
  recipe(Sale_Price ~ Gr_Liv_Area, data = ames)

the_workflow <- 
  workflow() %>% 
  add_recipe(the_rec) %>% 
  add_model(the_lm_model) 


the_results <- 
  fit_resamples(the_workflow, folds)

collect_metrics(the_results)

Is there a straightforward way to use caret resamples (from caret::createFolds()) in a tidymodels workflow, in place of the ones normally created with rsample::vfold_cv()? I am hoping that if I can figure out this detail, I can replicate complex old code with the new ecosystem (for teaching).

Upvotes: 1

Views: 246

Answers (1)

missuse

Reputation: 19716

Edit, thanks to Julia Silge's comment:

The rsample functions rsample2caret() and caret2rsample() can be used to convert resampling objects between formats.
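A minimal sketch of both directions, assuming the ames data and the 10-fold setup from the question. The object names folds_from_caret, caret_idx and ctrl are made up for illustration. caret2rsample() expects a control object whose index and indexOut elements are already populated, so the sketch takes the one stored in a fitted train() object.

```r
library(caret)
library(tidymodels)
data(ames)

# caret -> rsample: reuse the folds that train() generated
set.seed(123)
cv_model1 <- train(
  Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = trainControl(method = "cv", number = 10)
)
# the control stored in the fitted object has index and indexOut filled in
folds_from_caret <- caret2rsample(cv_model1$control, data = ames)

# rsample -> caret: reuse the folds that vfold_cv() generated
set.seed(123)
folds <- vfold_cv(ames, v = 10)
caret_idx <- rsample2caret(folds)  # list with elements index and indexOut
ctrl <- trainControl(
  method = "cv",
  index = caret_idx$index,
  indexOut = caret_idx$indexOut
)
```

folds_from_caret can then be passed to fit_resamples(), and ctrl to the trControl argument of train().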

The answer below can be useful to convert from arbitrary formats to rsample.

Old Answer

Here is an approach to convert the output of caret::createFolds() to an rsample object:

library(caret)
library(tidyverse)
library(tidymodels)

data(ames)

#create train folds
set.seed(123)
folds_train <- caret::createFolds(ames$Sale_Price, returnTrain = TRUE, k = 10)

#get test indexes
folds_test <- lapply(folds_train, function(x) setdiff(seq_along(ames$Sale_Price), x))

Combine the train and test indexes into a list of analysis/assessment lists, as described in the manual_rset() documentation:

rsplit <- map2(folds_train,
               folds_test,
               function(x,y) list(analysis = x, assessment = y))

splits <- lapply(rsplit, make_splits, data = ames)
splits <- manual_rset(splits, names(splits))
> splits
# Manual resampling 
# A tibble: 10 x 2
   splits             id    
   <named list>       <chr> 
 1 <split [2637/293]> Fold01
 2 <split [2638/292]> Fold02
 3 <split [2637/293]> Fold03
 4 <split [2637/293]> Fold04
 5 <split [2638/292]> Fold05
 6 <split [2637/293]> Fold06
 7 <split [2637/293]> Fold07
 8 <split [2636/294]> Fold08
 9 <split [2636/294]> Fold09
10 <split [2637/293]> Fold10
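As a quick sanity check (not part of the original approach), the sketch below confirms that the test indexes derived with setdiff() form a proper partition, meaning every row of ames lands in exactly one assessment set. The name all_test is illustrative only.

```r
library(caret)
library(tidymodels)
data(ames)

set.seed(123)
folds_train <- createFolds(ames$Sale_Price, returnTrain = TRUE, k = 10)
folds_test <- lapply(folds_train, function(x) setdiff(seq_len(nrow(ames)), x))

# pool the held-out row numbers from all ten folds
all_test <- unlist(folds_test, use.names = FALSE)

stopifnot(
  anyDuplicated(all_test) == 0,            # no row is held out twice
  setequal(all_test, seq_len(nrow(ames)))  # every row is held out once
)
```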

Check whether caret and tidymodels now give the same result:

set.seed(123)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area, 
  data = ames,
  method = "lm",
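  # method is left at its default ("boot"), so the printout below is labelled
  # "Bootstrapped", but the resamples actually used are the supplied folds
  trControl = trainControl(index = folds_train))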
> cv_model1
Linear Regression 

2930 samples
   1 predictor

No pre-processing
Resampling: Bootstrapped (10 reps) 
Summary of sample sizes: 2637, 2638, 2637, 2637, 2638, 2637, ... 
Resampling results:

  RMSE      Rsquared   MAE     
  56364.67  0.5066935  38575.21

Tuning parameter 'intercept' was held constant at a value of TRUE

the_lm_model <- 
  linear_reg() %>% 
  set_engine("lm")

the_rec <- 
  recipe(Sale_Price ~ Gr_Liv_Area, data = ames)

the_workflow <- 
  workflow() %>% 
  add_recipe(the_rec) %>% 
  add_model(the_lm_model) 

set.seed(123)
the_results <- 
  fit_resamples(the_workflow, splits)

collect_metrics(the_results)
> collect_metrics(the_results)
# A tibble: 2 x 6
  .metric .estimator      mean     n   std_err .config             
  <chr>   <chr>          <dbl> <int>     <dbl> <chr>               
1 rmse    standard   56365.       10 1782.     Preprocessor1_Model1
2 rsq     standard       0.507    10    0.0220 Preprocessor1_Model1

all.equal(
cv_model1$results$RMSE,
collect_metrics(the_results)$mean[1])
TRUE

Perhaps there is a more straightforward way, but I don't use tidymodels enough to know for sure.

If you did not create the folds before calling caret::train():

set.seed(123)
cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area, 
  data = ames,
  method = "lm",
  trControl = trainControl(number = 10, method = "cv"))

you can use the indexes stored in the fitted object,

cv_model1$control$index

cv_model1$control$indexOut

to create an rsample object:

rsplit <- map2(cv_model1$control$index,
               cv_model1$control$indexOut,
               function(x,y) list(analysis = x, assessment = y))

and proceed as described above:

splits <- lapply(rsplit, make_splits, data = ames)
splits <- manual_rset(splits, names(splits))

Upvotes: 1
