Mael Fosso

Reputation: 400

How to get the training error from fit_resamples and hyperparameter tuning?

During cross-validation, fit_resamples returns the average of each metric across the validation sets.

# A linear regression model specification
lr_model <-
  linear_reg() |>
  set_engine('lm')

# Bundle the preprocessing recipe and the model into a workflow
lr_wf <-
  workflow() |>
  add_recipe(basic_recipe) |>
  add_model(lr_model)

# Fit the workflow on each cross-validation fold
lr_cv <-
  lr_wf |>
  fit_resamples(
    folds,
    metrics = metric_set(rmse),
    control = control
  )
  
# Let's extract the CV results; they will help us compare this model with others
lr_cv |>
  collect_metrics()
# That's the RMSE validation error
# .metric .estimator  mean     n  std_err .config             
# <chr>   <chr>      <dbl> <int>    <dbl> <chr>               
# rmse    standard   0.161    10 0.000370 Preprocessor1_Model1

The issue I have is how to get the training error.

The same issue occurs after the tuning of hyperparameters.

For example, when tuning KNN to find the best number of neighbors, collect_metrics and show_best return the average of the metrics on the validation sets from cross-validation. Yet, as we all know, the best number of neighbors is found where the validation error starts to increase while the training error keeps decreasing. Unfortunately, the autoplot function shows only the validation errors, not the training errors.

In this case, for example:

# A regular grid over the three tree hyperparameters
tree_grid <-
  grid_regular(
    cost_complexity(),
    tree_depth(),
    min_n(),
    levels = c(3, 5, 10)
  )

# Bundle the model and the recipe into a workflow
tree_wf <-
  workflow() %>%
  add_model(tree_model) %>%
  add_recipe(basic_recipe)

# Tune over the grid using cross-validation
tree_res <- 
  tree_wf %>%
  tune_grid(
    resamples = folds,
    grid = tree_grid,
    metrics = metric_set(rmse),
    control = control
  )

How can I extract the training error for each hyperparameter/fold combination?

Upvotes: 1

Views: 56

Answers (1)

topepo

Reputation: 14331

UPDATE: based on more specifics, see the code at the end:

tidymodels doesn't have a method to automatically generate them because, for many models, they are highly problematic (e.g., unrealistically optimistic). That can be frustrating but, overall, more harm than good would be done if those results were included.

You can do a little work to get them using the extract options for the control functions.

Here is a reproducible example:

library(tidymodels)
tidymodels_prefer()

set.seed(1)
reg_dat <- sim_regression(200)
reg_rs <- vfold_cv(reg_dat)

tree_grid <-
  grid_space_filling(
    cost_complexity(),
    tree_depth(),
    min_n(),
    size = 3)

tree_spec <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune(),
    min_n = tune()
  ) %>%
  set_mode("regression")

tree_wf <- workflow(outcome ~ ., tree_spec)

# Write a function to make predictions: 
get_train_rmse <- function(x) {
  require(tidymodels) # <- sometimes needed for parallel proc
  # 'x' is the fitted workflow
  augment(x, reg_dat) %>% 
    rmse(outcome, .pred)
}

# pass it in:
ctrl <- control_grid(extract = get_train_rmse)  
  
tree_res <- 
  tree_wf %>%
  tune_grid(
    resamples = reg_rs,
    grid = tree_grid,
    metrics = metric_set(rmse),
    control = ctrl
  )

names(tree_res)
#> [1] "splits"    "id"        ".metrics"  ".notes"    ".extracts"

collect_extracts(tree_res) %>% 
  unnest(.extracts) %>% 
  relocate(training_error = .estimate)
#> # A tibble: 30 × 8
#>    training_error id     cost_complexity tree_depth min_n .metric .estimator
#>             <dbl> <chr>            <dbl>      <int> <int> <chr>   <chr>     
#>  1           12.9 Fold01    0.0000000001          8    21 rmse    standard  
#>  2           18.0 Fold01    0.00000316            1     2 rmse    standard  
#>  3           16.2 Fold01    0.1                  15    40 rmse    standard  
#>  4           12.3 Fold02    0.0000000001          8    21 rmse    standard  
#>  5           18.1 Fold02    0.00000316            1     2 rmse    standard  
#>  6           16.1 Fold02    0.1                  15    40 rmse    standard  
#>  7           12.5 Fold03    0.0000000001          8    21 rmse    standard  
#>  8           18.0 Fold03    0.00000316            1     2 rmse    standard  
#>  9           16.4 Fold03    0.1                  15    40 rmse    standard  
#> 10           12.4 Fold04    0.0000000001          8    21 rmse    standard  
#> # ℹ 20 more rows
#> # ℹ 1 more variable: .config <chr>

Created on 2025-01-08 with reprex v2.1.0
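
To line these training errors up with the validation summary from collect_metrics(), they can also be averaged per candidate. A minimal sketch using the same objects as above:

# Average the training RMSE over the folds for each candidate:
collect_extracts(tree_res) %>% 
  unnest(.extracts) %>% 
  group_by(.config, cost_complexity, tree_depth, min_n) %>% 
  summarize(mean_train_rmse = mean(.estimate), .groups = "drop")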

Note that this may not automatically work with some parallel processing technologies since reg_dat is not available in the worker processes. However, the parallel package has functions that can pass the data (or other objects) to each worker to make it available for computations.
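
For example, a minimal sketch assuming a foreach/doParallel backend (the cluster size is arbitrary; recent tune releases favor the future framework, where this is handled differently):

library(doParallel)

cl <- parallel::makePSOCKcluster(4)     # start four R worker processes
parallel::clusterExport(cl, "reg_dat")  # copy reg_dat into each worker
registerDoParallel(cl)                  # register the workers with foreach

# ... run tune_grid() as above ...

parallel::stopCluster(cl)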

UPDATE: based on this feedback:

Thank you, @topepo, for your reply, but I have a problem with that. You are reusing the whole training dataset instead of using the training part of the data from the current fold. How do I access the training part of the data of the current fold when running fit_resamples?

No problem. This is exactly why we don't use "training set" to describe the data used for modeling within resamples; we use the term analysis set to be more specific.
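
For reference, rsample gives direct access to both partitions of a split; a quick illustration, using the reg_rs object defined in the reprex below:

first_split <- reg_rs$splits[[1]]
analysis(first_split)    # the analysis set: the data the model is fit on
assessment(first_split)  # the assessment set: the held-out data for the metric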

Here's another reprex:

library(tidymodels)
tidymodels_prefer()

set.seed(1)
reg_dat <- sim_regression(200)
reg_rs <- vfold_cv(reg_dat)

tree_grid <-
  grid_space_filling(
    cost_complexity(),
    tree_depth(),
    min_n(),
    size = 3)

tree_spec <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune(),
    min_n = tune()
  ) %>%
  set_mode("regression")

tree_wf <- workflow(outcome ~ ., tree_spec)

# Write a function to return the fitted workflow: 
get_fit <- function(x) {
  x
}

# pass it in:
ctrl <- control_grid(extract = get_fit)  
rmse_set <- metric_set(rmse)

tree_res <- 
  tree_wf %>%
  tune_grid(
    resamples = reg_rs,
    grid = tree_grid,
    metrics = rmse_set,
    control = ctrl
  ) 

tree_res %>% 
  select(splits, starts_with("id"), .extracts) %>% 
  # Expand the models fit within each of the resamples
  unnest(.extracts) %>% 
  mutate(
    # Predict each splits analysis set
    tr_pred = map2(splits, .extracts, ~ augment(.y, analysis(.x))),
    # Compute metrics
    tr_metrics = map(tr_pred, ~ rmse_set(.x, outcome, .pred))
  ) %>% 
  select(starts_with("id"), cost_complexity, tree_depth, min_n, tr_metrics) %>% 
  unnest(tr_metrics)
#> # A tibble: 30 × 7
#>    id     cost_complexity tree_depth min_n .metric .estimator .estimate
#>    <chr>            <dbl>      <int> <int> <chr>   <chr>          <dbl>
#>  1 Fold01    0.0000000001          8    21 rmse    standard        12.2
#>  2 Fold01    0.00000316            1     2 rmse    standard        17.1
#>  3 Fold01    0.1                  15    40 rmse    standard        15.6
#>  4 Fold02    0.0000000001          8    21 rmse    standard        11.3
#>  5 Fold02    0.00000316            1     2 rmse    standard        17.7
#>  6 Fold02    0.1                  15    40 rmse    standard        15.7
#>  7 Fold03    0.0000000001          8    21 rmse    standard        11.7
#>  8 Fold03    0.00000316            1     2 rmse    standard        17.5
#>  9 Fold03    0.1                  15    40 rmse    standard        16.1
#> 10 Fold04    0.0000000001          8    21 rmse    standard        12.3
#> # ℹ 20 more rows

Created on 2025-01-10 with reprex v2.1.1
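
If the goal is to see where training and validation errors diverge, the per-fold validation metrics can be joined to these training errors. A minimal sketch, assuming the result of the pipeline above was saved as train_rmse (a hypothetical name):

# Per-fold validation RMSE from the tuning results:
val_rmse <- collect_metrics(tree_res, summarize = FALSE) %>% 
  select(id, cost_complexity, tree_depth, min_n, val_rmse = .estimate)

# Side-by-side training and validation RMSE for each fold/candidate:
train_rmse %>% 
  select(id, cost_complexity, tree_depth, min_n, train_rmse = .estimate) %>% 
  inner_join(val_rmse, by = c("id", "cost_complexity", "tree_depth", "min_n"))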

Upvotes: 1
