Reputation: 1
I am using fit_resamples() with a workflow() and V-fold cross-validation.
My model is logistic regression.
How can I get coefficients of a parsnip logistic regression model with V-Fold Cross-Validation?
If I use V-fold cross-validation with v = 5, I want to get the coefficients from each of the five fitted models.
Upvotes: 0
Views: 293
Reputation: 11633
You typically do not want to use fit_resamples() to train and keep five models; the main purpose of the fit_resamples() function is to use resampling to estimate performance. The five models are fit and then thrown away.
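To make "fit and then thrown away" concrete, here is a minimal base R sketch of what V-fold resampling does conceptually. It is not how tune implements fit_resamples(); the simulated data, the unstratified fold assignment, and the names dat, fold_id, and cv_results are all assumptions made for illustration.

```r
# Simulated binary-outcome data (hypothetical, for illustration only)
set.seed(123)
n <- 200
x <- rnorm(n)
y <- rbinom(n, size = 1, prob = plogis(0.5 + 1.2 * x))
dat <- data.frame(x = x, y = y)

# Randomly assign each row to one of v = 5 folds (no stratification here)
v <- 5
fold_id <- sample(rep(seq_len(v), length.out = n))

# For each fold: fit on the other four folds, score the held-out fold,
# and keep the coefficients alongside the accuracy
cv_results <- lapply(seq_len(v), function(k) {
  analysis   <- dat[fold_id != k, ]
  assessment <- dat[fold_id == k, ]
  fit  <- glm(y ~ x, data = analysis, family = binomial())
  prob <- predict(fit, newdata = assessment, type = "response")
  list(
    accuracy = mean((prob > 0.5) == (assessment$y == 1)),
    coefs    = coef(fit)
  )
})

# One row of coefficients per fold
coef_by_fold <- do.call(rbind, lapply(cv_results, `[[`, "coefs"))
rownames(coef_by_fold) <- paste0("Fold", seq_len(v))
coef_by_fold
```

Resampling for performance estimation only needs the accuracy element; the coefs element is the extra piece you have to ask for explicitly if you want to keep anything from the per-fold models.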
However, if you do have a use case where you want to keep the fitted models, such as in this article, then you would pass extract_model to control_resamples().
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
data(penguins)
set.seed(2021)
penguin_split <- penguins %>%
  filter(!is.na(sex)) %>%
  initial_split(strata = sex)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)
penguin_folds <- vfold_cv(penguin_train, v = 5, strata = sex)
penguin_folds
#> # 5-fold cross-validation using stratification
#> # A tibble: 5 x 2
#> splits id
#> <list> <chr>
#> 1 <split [198/51]> Fold1
#> 2 <split [199/50]> Fold2
#> 3 <split [199/50]> Fold3
#> 4 <split [200/49]> Fold4
#> 5 <split [200/49]> Fold5
glm_spec <- logistic_reg() %>%
  set_engine("glm")
glm_rs <- workflow() %>%
  add_formula(sex ~ species + bill_length_mm + bill_depth_mm + body_mass_g) %>%
  add_model(glm_spec) %>%
  fit_resamples(
    resamples = penguin_folds,
    control = control_resamples(extract = extract_model, save_pred = TRUE)
  )
Because extract_model was used during resampling, the results now include an .extracts column that holds the fitted model for each fold.
glm_rs
#> # Resampling results
#> # 5-fold cross-validation using stratification
#> # A tibble: 5 x 6
#> splits id .metrics .notes .extracts .predictions
#> <list> <chr> <list> <list> <list> <list>
#> 1 <split [198/… Fold1 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [51 × …
#> 2 <split [199/… Fold2 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 3 <split [199/… Fold3 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 4 <split [200/… Fold4 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …
#> 5 <split [200/… Fold5 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …
glm_rs$.extracts[[1]]
#> # A tibble: 1 x 2
#> .extracts .config
#> <list> <chr>
#> 1 <glm> Preprocessor1_Model1
You can use tidyr and broom functions to get the coefficients out, if that is what you are looking for.
glm_rs %>%
  dplyr::select(id, .extracts) %>%
  unnest(cols = .extracts) %>%
  mutate(tidied = map(.extracts, tidy)) %>%
  unnest(tidied)
#> # A tibble: 30 x 8
#> id .extracts .config term estimate std.error statistic p.value
#> <chr> <list> <chr> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 Fold1 <glm> Preprocesso… (Interce… -7.44e+1 12.6 -5.89 3.75e-9
#> 2 Fold1 <glm> Preprocesso… speciesC… -6.59e+0 1.82 -3.61 3.03e-4
#> 3 Fold1 <glm> Preprocesso… speciesG… -7.49e+0 2.54 -2.95 3.18e-3
#> 4 Fold1 <glm> Preprocesso… bill_len… 5.56e-1 0.151 3.67 2.40e-4
#> 5 Fold1 <glm> Preprocesso… bill_dep… 1.72e+0 0.424 4.06 4.83e-5
#> 6 Fold1 <glm> Preprocesso… body_mas… 5.88e-3 0.00130 4.51 6.44e-6
#> 7 Fold2 <glm> Preprocesso… (Interce… -6.87e+1 11.3 -6.06 1.37e-9
#> 8 Fold2 <glm> Preprocesso… speciesC… -5.59e+0 1.75 -3.20 1.39e-3
#> 9 Fold2 <glm> Preprocesso… speciesG… -7.61e+0 2.80 -2.71 6.65e-3
#> 10 Fold2 <glm> Preprocesso… bill_len… 4.88e-1 0.145 3.36 7.88e-4
#> # … with 20 more rows
Created on 2021-06-27 by the reprex package (v2.0.0)
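More recent tidymodels releases deprecate extract_model() in favor of extract_fit_engine(). The sketch below is one way the same idea might look with that function, using a custom extract function that returns tidied coefficients directly instead of the whole glm object. The helper name get_coefs, the object fold_coefs, and the choice to skip the initial train/test split are assumptions made to keep the example short; it also assumes a tidymodels installation recent enough to provide extract_fit_engine().

```r
library(tidymodels)

# Load penguins explicitly from modeldata, since newer versions of R also
# ship a `penguins` data set with different column names
data(penguins, package = "modeldata")

set.seed(2021)
penguin_folds <- penguins %>%
  filter(!is.na(sex)) %>%
  vfold_cv(v = 5, strata = sex)

# Hypothetical helper: receives the fitted workflow for one fold and
# returns a tibble of coefficients from the underlying glm
get_coefs <- function(wf) {
  wf %>%
    extract_fit_engine() %>%
    tidy()
}

glm_rs <- workflow() %>%
  add_formula(sex ~ species + bill_length_mm + bill_depth_mm + body_mass_g) %>%
  add_model(logistic_reg() %>% set_engine("glm")) %>%
  fit_resamples(
    resamples = penguin_folds,
    control = control_resamples(extract = get_coefs)
  )

# First unnest gives one row per fold; second gives one row per fold and term
fold_coefs <- glm_rs %>%
  dplyr::select(id, .extracts) %>%
  unnest(cols = .extracts) %>%
  unnest(cols = .extracts)

# Spread of each coefficient across the five folds
fold_coefs %>%
  group_by(term) %>%
  summarise(mean_estimate = mean(estimate), sd_estimate = sd(estimate))
```

Returning a small tibble from the extract function keeps the resampling results lighter than storing five full model objects, which may matter with larger data or more folds.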
Upvotes: 4