Rafael Santamaria

Reputation: 31

Logistic regression with tidymodels: How to set event_level = "second" in last_fit()?

I am building a logistic regression model with an outcome variable with 2 categories: a_category / z_category, and I have the following questions:

  1. I am interested in predicting "z_category" from the independent variables, so my reference category should be "a_category". Since "a_category" is already the first level of the variable, there is no need to relevel the outcome, and the code could be:

Splits:

splits <- initial_split(df1, strata = outcome, prop = 3/4)
training_set <- training(splits)
test_set <- testing(splits)

Recipe:

glm_rec <-
  recipe(outcome ~ ., data = training_set) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors()) %>%
  step_dummy(all_nominal(), -all_outcomes())

Model spec:

glm_spec <- 
  logistic_reg() %>% 
  set_engine("glm") 

Workflow:

glm_final_wf <- 
  workflow() %>% 
  add_model(glm_spec) %>% 
  add_recipe(glm_rec)

Am I right?

  2. Internal validation and ROC curves: I am using event_level = "second" to calculate the metrics and the ROC curve with yardstick functions:

# metrics
glm_internalval_res <- glm_final_wf %>%
  fit_resamples(
    resamples = vfold_cv(training_set,
                         v = 10,
                         repeats = 2,
                         strata = outcome),
    control = control_resamples(save_pred = TRUE, event_level = "second"),
    metrics = metric_set(
      yardstick::roc_auc,
      yardstick::accuracy,
      yardstick::sens,
      yardstick::spec,
      yardstick::precision,
      yardstick::ppv,
      yardstick::npv
    )
  )

# ROC curve
glm_internalval_res %>%
  collect_predictions() %>%
  group_by(id, id2) %>%
  roc_curve(truth = outcome,
            .pred_z_category,
            event_level = "second") %>%
  autoplot()

Am I right?

  3. External validation with last_fit(): I cannot find how to set event_level = "second". When I try:

glm_externalval_res <-
  last_fit(
    glm_final_wf,
    splits,
    metrics = metric_set(
      yardstick::roc_auc,
      yardstick::accuracy,
      yardstick::sens,
      yardstick::spec,
      yardstick::precision,
      yardstick::ppv,
      yardstick::npv
    )
  )

With this chunk, the metrics are computed with the first category, "a_category", as the event, and I think this is not correct.

I am wondering how to tell last_fit() that my category of interest is "z_category". I couldn't find an answer in the package documentation.

Thanks.

Rafael.

Upvotes: 3

Views: 1005

Answers (2)

Ryan John

Reputation: 1430

One option is to set the global option so that the second level is treated as the event:

Pre 0.0.7:

options(yardstick.event_first = FALSE)

Post 0.0.7:

options(yardstick.event_level = 'second')
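
If a global option feels too broad, the event level can also be passed to last_fit() itself. The sketch below assumes a version of tune recent enough to provide control_last_fit() with an event_level argument (older releases did not have it). It uses the two_class_dat data from modeldata as a stand-in for df1, so the data and the demo_* object names are illustrative only:

```r
library(tidymodels)

# Stand-in data (not the asker's df1): outcome `Class` has levels Class1, Class2
data(two_class_dat, package = "modeldata")

set.seed(123)
demo_split <- initial_split(two_class_dat, strata = Class, prop = 3/4)

demo_wf <-
  workflow() %>%
  add_formula(Class ~ .) %>%
  add_model(logistic_reg() %>% set_engine("glm"))

# Assumption: control_last_fit() is available in the installed tune version
demo_res <-
  last_fit(
    demo_wf,
    demo_split,
    metrics = metric_set(roc_auc, sens, spec),
    control = control_last_fit(event_level = "second")
  )

collect_metrics(demo_res)
```

In this sketch sens and spec treat Class2, the second level, as the event, which mirrors what control_resamples(event_level = "second") does in the question's fit_resamples() call.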

Upvotes: 1

Julia Silge

Reputation: 11623

The easiest thing to do is definitely to rename your levels so the one that you are interested in is first. However, if that is not what you want to do, then you need to create a version of the metric with the event_level option set and put it into a metric_set(). The procedure for this is outlined in the docs for metric_set().

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

roc_auc_with_event_level <- function(data, truth, ..., na_rm = TRUE) {
   roc_auc(
      data = data,
      truth = !! rlang::enquo(truth),
      ...,
      na_rm = na_rm,
      # set event level
      event_level = "second"
   )
}

roc_auc_with_event_level <- new_prob_metric(roc_auc_with_event_level, "maximize")

ms <- metric_set(accuracy, roc_auc_with_event_level)
ms
#> # A tibble: 2 × 3
#>   metric                   class        direction
#>   <chr>                    <chr>        <chr>    
#> 1 accuracy                 class_metric maximize 
#> 2 roc_auc_with_event_level prob_metric  maximize

Created on 2021-08-01 by the reprex package (v2.0.0)

Now you can use this metric set ms in tuning functions like last_fit(metrics = ms).
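
A shorter route to the same kind of metric, assuming a yardstick version that includes metric_tweak(), is to let that helper return a copy of an existing metric with optional arguments such as event_level fixed. The sketch below evaluates the metric set directly on yardstick's bundled two_class_example data rather than on the asker's data, and the names roc_auc_second, sens_second, ms_second, and second_level_res are made up for illustration:

```r
library(yardstick)

# Copies of existing metrics with event_level fixed to the second factor level
roc_auc_second <- metric_tweak("roc_auc_second", roc_auc, event_level = "second")
sens_second <- metric_tweak("sens_second", sens, event_level = "second")

ms_second <- metric_set(accuracy, sens_second, roc_auc_second)

# Example data shipped with yardstick: `truth` has levels Class1 and Class2,
# `Class1`/`Class2` hold class probabilities, `predicted` holds hard predictions
data(two_class_example)

# When calling the metric set directly, supply the probability column of the
# second level (Class2), because that level is now the event
second_level_res <- ms_second(
  two_class_example,
  truth = truth,
  Class2,
  estimate = predicted
)
second_level_res
```

When the metric set is called directly like this, the probability column has to match the event level; passing Class1 together with event_level = "second" would score the wrong class.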

Upvotes: 2
