Salivan

Reputation: 167

How to simulate last_fit() using fit() in tidymodels?

I would like to apply the random forest method to predict the wait time of patients at a hospital. I closely follow the instructions at https://www.tidymodels.org/start/case-study/ to tune my model. After obtaining my best model, I create the last_rf_workflow object as described in the link above.

last_rf_workflow <- 
  Data_rf_wflow %>% 
  update_model(last_rf_mod)

Then, I use the code below to fit the final model:

set.seed(345)
last_rf_fit <- 
  last_rf_workflow %>% 
  last_fit(data_split)

According to its documentation, last_fit() fits the final model on the entire training set and evaluates it on the testing set. The test-set predictions can be retrieved with collect_predictions(last_rf_fit).

However, when I fit the model to the entire training set and then use the predict() function, I get slightly different predictions:

set.seed(345)
last_rf_fit_2 <- 
  last_rf_workflow %>% 
  fit(training(data_split))

predict(last_rf_fit_2, testing(data_split))

I wonder if someone could help me understand why these two predictions are different. Thanks.

Upvotes: 1

Views: 2182

Answers (1)

hnagaty

Reputation: 848

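I think it's because ranger needs its own separate seed. Unless a seed is passed to the engine, the seed ranger uses is drawn from R's random number generator at fit time. last_fit() may consume random numbers internally before the model is fitted, so set.seed(345) does not necessarily leave the generator in the same state for both fits. I passed a seed to the engine with set_engine("ranger", seed = 123).
Building on Julia's reproducible code, I reproduced the setup, and with the engine seed set the two sets of predictions are the same.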
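To see the mechanism outside tidymodels, here is a minimal sketch that calls ranger directly on mtcars. The extra runif(1) call is only a hypothetical stand-in for random numbers that other code might draw before the model is fitted; it is an assumption for illustration, not what last_fit() literally does. num.threads = 1 is set only to keep the example simple.

```r
library(ranger)

# Without an explicit seed, ranger draws its seed from R's RNG,
# so the fitted forest depends on the RNG state when ranger() is called.
set.seed(345)
fit_a <- ranger(mpg ~ ., data = mtcars, num.trees = 100, num.threads = 1)

set.seed(345)
invisible(runif(1))  # hypothetical stand-in for random numbers drawn by other code first
fit_b <- ranger(mpg ~ ., data = mtcars, num.trees = 100, num.threads = 1)

# With an explicit engine seed, R's RNG state no longer matters.
set.seed(345)
fit_c <- ranger(mpg ~ ., data = mtcars, num.trees = 100, num.threads = 1, seed = 123)
invisible(runif(1))
fit_d <- ranger(mpg ~ ., data = mtcars, num.trees = 100, num.threads = 1, seed = 123)

pred_a <- predict(fit_a, data = mtcars)$predictions
pred_b <- predict(fit_b, data = mtcars)$predictions
pred_c <- predict(fit_c, data = mtcars)$predictions
pred_d <- predict(fit_d, data = mtcars)$predictions

isTRUE(all.equal(pred_a, pred_b))  # FALSE: different RNG state, different forest
identical(pred_c, pred_d)          # TRUE: same engine seed, same forest
```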

library(tidymodels)

set.seed(123)
tr_te_split <- initial_split(mtcars)

rf_spec <- rand_forest() %>%
  set_mode("regression") %>%
  set_engine("ranger", seed = 123)

rf_wf <- workflow() %>%
  add_model(rf_spec) %>%
  add_formula(mpg ~ .)

set.seed(345)
last_rf_fit <- last_fit(rf_wf, split = tr_te_split)
collect_predictions(last_rf_fit)
#> # A tibble: 8 x 4
#>   id               .pred  .row   mpg
#>   <chr>            <dbl> <int> <dbl>
#> 1 train/test split  24.8     3  22.8
#> 2 train/test split  18.8    10  19.2
#> 3 train/test split  16.5    14  15.2
#> 4 train/test split  13.6    15  10.4
#> 5 train/test split  28.2    18  32.4
#> 6 train/test split  29.2    19  30.4
#> 7 train/test split  17.3    22  15.5
#> 8 train/test split  15.3    31  15


set.seed(345)
last_rf_fit_2 <- fit(rf_wf, training(tr_te_split))
predict(last_rf_fit_2, testing(tr_te_split))
#> # A tibble: 8 x 1
#>   .pred
#>   <dbl>
#> 1  24.8
#> 2  18.8
#> 3  16.5
#> 4  13.6
#> 5  28.2
#> 6  29.2
#> 7  17.3
#> 8  15.3

Created on 2020-08-19 by the reprex package (v0.3.0)
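The printed tibbles only show three significant digits, so matching printouts do not by themselves prove the predictions are identical. Here is a minimal sketch of a stricter check using the same setup; pred_last and pred_fit are hypothetical names, and sorting by .row is an assumption added to make sure both sets are in the same row order as testing(tr_te_split).

```r
library(tidymodels)

set.seed(123)
tr_te_split <- initial_split(mtcars)

rf_spec <- rand_forest() %>%
  set_mode("regression") %>%
  set_engine("ranger", seed = 123)

rf_wf <- workflow() %>%
  add_model(rf_spec) %>%
  add_formula(mpg ~ .)

# Predictions from last_fit(), ordered by test-set row number
pred_last <- last_fit(rf_wf, split = tr_te_split) %>%
  collect_predictions() %>%
  arrange(.row)

# Predictions from fit() + predict() on the same split
pred_fit <- fit(rf_wf, training(tr_te_split)) %>%
  predict(testing(tr_te_split))

# Compares the full double-precision values, not just the printed digits
all.equal(pred_last$.pred, pred_fit$.pred)
```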

Upvotes: 4
