Ken

Reputation: 1

How can I get coefficients of a parsnip multinomial logistic regression model with vfold_cv?

I use fit_resamples() with workflow() and V-fold cross-validation. My model is logistic regression.

How can I get the coefficients of a parsnip logistic regression model fit with V-fold cross-validation?

With v = 5, I want to get the coefficients from each of the 5 fits.

Upvotes: 0

Views: 293

Answers (1)

Julia Silge

Reputation: 11633

You typically do not want to use fit_resamples() to train and keep five models; the main purpose of the fit_resamples() function is to use resampling to estimate performance. The five models are fit and then thrown away.

However, if you do have some use case where you want to keep around the models that are fit, such as in this article, then you can pass extract_model as the extract argument of control_resamples().

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
data(penguins)

set.seed(2021)
penguin_split <- penguins %>%
  filter(!is.na(sex)) %>%
  initial_split(strata = sex)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)

penguin_folds <- vfold_cv(penguin_train, v = 5, strata = sex)
penguin_folds
#> #  5-fold cross-validation using stratification 
#> # A tibble: 5 x 2
#>   splits           id   
#>   <list>           <chr>
#> 1 <split [198/51]> Fold1
#> 2 <split [199/50]> Fold2
#> 3 <split [199/50]> Fold3
#> 4 <split [200/49]> Fold4
#> 5 <split [200/49]> Fold5

glm_spec <- logistic_reg() %>%
  set_engine("glm") 

glm_rs <- workflow() %>%
  add_formula(sex ~ species + bill_length_mm + bill_depth_mm + body_mass_g) %>%
  add_model(glm_spec) %>%
  fit_resamples(
    resamples = penguin_folds,
    control = control_resamples(extract = extract_model, save_pred = TRUE)
  )

Now that you have used extract_model in your resampling, it is there in your results and you have the models available for each fold.

glm_rs
#> # Resampling results
#> # 5-fold cross-validation using stratification 
#> # A tibble: 5 x 6
#>   splits        id    .metrics       .notes        .extracts     .predictions   
#>   <list>        <chr> <list>         <list>        <list>        <list>         
#> 1 <split [198/… Fold1 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [51 × …
#> 2 <split [199/… Fold2 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 3 <split [199/… Fold3 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 4 <split [200/… Fold4 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …
#> 5 <split [200/… Fold5 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …

glm_rs$.extracts[[1]]
#> # A tibble: 1 x 2
#>   .extracts .config             
#>   <list>    <chr>               
#> 1 <glm>     Preprocessor1_Model1

You can use unnest(), map(), and tidy() to get the coefficients out, if that is what you are looking for.

glm_rs %>% 
  dplyr::select(id, .extracts) %>%
  unnest(cols = .extracts) %>%
  mutate(tidied = map(.extracts, tidy)) %>%
  unnest(tidied)
#> # A tibble: 30 x 8
#>    id    .extracts .config      term      estimate std.error statistic   p.value
#>    <chr> <list>    <chr>        <chr>        <dbl>     <dbl>     <dbl>     <dbl>
#>  1 Fold1 <glm>     Preprocesso… (Interce… -7.44e+1  12.6         -5.89   3.75e-9
#>  2 Fold1 <glm>     Preprocesso… speciesC… -6.59e+0   1.82        -3.61   3.03e-4
#>  3 Fold1 <glm>     Preprocesso… speciesG… -7.49e+0   2.54        -2.95   3.18e-3
#>  4 Fold1 <glm>     Preprocesso… bill_len…  5.56e-1   0.151        3.67   2.40e-4
#>  5 Fold1 <glm>     Preprocesso… bill_dep…  1.72e+0   0.424        4.06   4.83e-5
#>  6 Fold1 <glm>     Preprocesso… body_mas…  5.88e-3   0.00130      4.51   6.44e-6
#>  7 Fold2 <glm>     Preprocesso… (Interce… -6.87e+1  11.3         -6.06   1.37e-9
#>  8 Fold2 <glm>     Preprocesso… speciesC… -5.59e+0   1.75        -3.20   1.39e-3
#>  9 Fold2 <glm>     Preprocesso… speciesG… -7.61e+0   2.80        -2.71   6.65e-3
#> 10 Fold2 <glm>     Preprocesso… bill_len…  4.88e-1   0.145        3.36   7.88e-4
#> # … with 20 more rows
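
Once the per-fold coefficients are in one tibble, it can help to see how much each term varies across folds. The following is a minimal sketch, not part of the original reprex. It assumes a tidymodels installation recent enough to provide extract_fit_engine(), which is used here in place of extract_model because the latter was deprecated in later releases of tune. The names coef_by_fold and coef_summary are made up for illustration.

```r
library(tidymodels)

# Condensed version of the setup above so this block runs on its own.
# The package is named explicitly in case another `penguins` dataset is visible.
data(penguins, package = "modeldata")

set.seed(2021)
penguin_train <- penguins %>%
  filter(!is.na(sex)) %>%
  initial_split(strata = sex) %>%
  training()
penguin_folds <- vfold_cv(penguin_train, v = 5, strata = sex)

glm_rs <- workflow() %>%
  add_formula(sex ~ species + bill_length_mm + bill_depth_mm + body_mass_g) %>%
  add_model(logistic_reg() %>% set_engine("glm")) %>%
  fit_resamples(
    resamples = penguin_folds,
    # Assumption: extract_fit_engine() is available in the installed version
    control = control_resamples(extract = function(x) extract_fit_engine(x))
  )

# One row per fold and term, as in the answer above
coef_by_fold <- glm_rs %>%
  dplyr::select(id, .extracts) %>%
  unnest(cols = .extracts) %>%
  mutate(tidied = map(.extracts, tidy)) %>%
  unnest(tidied)

# Mean and spread of each coefficient across the five folds
coef_summary <- coef_by_fold %>%
  group_by(term) %>%
  summarise(
    mean_estimate = mean(estimate),
    sd_estimate = sd(estimate),
    n_folds = n()
  )

coef_summary
```

A large sd_estimate relative to mean_estimate suggests that a coefficient is sensitive to which rows end up in the analysis set.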

Created on 2021-06-27 by the reprex package (v2.0.0)
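
If the coefficients are all you need, one alternative is to skip fit_resamples() and fit each fold by hand. This is only a sketch under stated assumptions, not what tune does internally: it calls stats::glm() directly on the analysis set of each split, and it assumes the rsample, dplyr, purrr, tidyr, broom, and modeldata packages are installed. The helper fit_one_fold and the result name manual_coefs are hypothetical.

```r
library(rsample)
library(dplyr)
library(purrr)
library(tidyr)
library(broom)

data(penguins, package = "modeldata")

set.seed(2021)
penguin_train <- penguins %>%
  filter(!is.na(sex)) %>%
  initial_split(strata = sex) %>%
  training()
penguin_folds <- vfold_cv(penguin_train, v = 5, strata = sex)

# Hypothetical helper: fit a logistic regression on the analysis
# (in-fold training) rows of a single split
fit_one_fold <- function(split) {
  glm(
    sex ~ species + bill_length_mm + bill_depth_mm + body_mass_g,
    data = analysis(split),
    family = binomial()
  )
}

# Fit once per fold, tidy each fit, and stack the results
manual_coefs <- tibble(
  id = penguin_folds$id,
  tidied = map(penguin_folds$splits, ~ tidy(fit_one_fold(.x)))
) %>%
  unnest(tidied)

manual_coefs
```

This gives up the performance metrics and preprocessing handling that fit_resamples() provides, so it fits best when the model has no recipe steps that must be re-estimated within each fold.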

Upvotes: 4
