zesla

Extract contents from confusionMatrix objects saved in a list column with dplyr

As shown in the code below, after cross-validation I'm trying to extract model metrics for each fold. I saved all predictions from resampling, grouped the data by fold, computed the confusion matrix for each group, and saved each confusionMatrix object in a list column, cm. Now I need to extract metrics such as precision from the objects stored in that column. My example code is below.

library(caret)
library(dplyr)   # %>%, filter(), mutate(), group_by(), do()
library(purrr)   # map()
iris2 = iris %>% 
    filter(Species != 'setosa') %>%
    mutate(Species = factor(Species))

train.control <- trainControl(method="cv", 
                           number=5,
                           summaryFunction = twoClassSummary,
                           classProbs = TRUE,
                           savePredictions='all')
rf = train(Species~., data=iris2,  method = 'rf',
           metric = 'ROC', trControl=train.control)
rf$pred %>% group_by(Resample) %>%
    do(cm = confusionMatrix(.$pred, .$obs),
       Accuracy = map(cm, ~.x$byClass['Precision'])) 

I got this error message:

Error in .x$byClass : $ operator is invalid for atomic vectors

I could not figure out why it does not work. How can I modify the last line to make it work? Thanks.

Answers (1)

Vivek Katial

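You can compute one confusionMatrix per fold with do(), then ungroup() the result and use mutate() to add one column per metric: map over the cm list column, pull the named element out of each object's byClass vector with [[ ]], and collapse the resulting list into a numeric vector with unlist() (map_dbl() does this in one step). Note that precision is stored in byClass, whereas accuracy lives in the object's overall component (cm$overall[["Accuracy"]]).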
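The extraction relies on each element of cm being a confusionMatrix object, i.e. a list whose byClass and overall components are named numeric vectors. A minimal sketch on hypothetical toy predictions for a single fold (not the iris data from the question) shows the indexing on one object; by default the first factor level, "versicolor", is treated as the positive class:

```r
library(caret)

# Hypothetical toy predictions for one fold (illustrative only)
lv   <- c("versicolor", "virginica")
pred <- factor(c("versicolor", "versicolor", "virginica", "virginica", "versicolor"), levels = lv)
obs  <- factor(c("versicolor", "virginica",  "virginica", "virginica", "versicolor"), levels = lv)

cm <- confusionMatrix(pred, obs)

# A confusionMatrix object is a list; the metrics sit in named numeric vectors
names(cm)                          # includes "overall" and "byClass"
cm$overall[["Accuracy"]]           # 0.8  (4 of 5 correct)
cm$byClass[["Precision"]]          # 2/3  (2 TP, 1 FP)
cm$byClass[["Neg Pred Value"]]     # 1    (2 TN, 0 FN)
```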

library(purrr)  # map_dbl()

rf$pred %>% 
  group_by(Resample) %>%
  do(cm = confusionMatrix(.$pred, .$obs)) %>%  # one confusionMatrix per fold in list column cm
  ungroup() %>% 
  mutate(neg_pred_value = map_dbl(cm, ~ .x[["byClass"]][["Neg Pred Value"]]),
         precision      = map_dbl(cm, ~ .x[["byClass"]][["Precision"]]))

Running the code above returns the following tibble:

# A tibble: 5 x 4
  Resample                    cm neg_pred_value precision
     <chr>                <list>          <dbl>     <dbl>
1    Fold1 <S3: confusionMatrix>      0.9090909 1.0000000
2    Fold2 <S3: confusionMatrix>      1.0000000 1.0000000
3    Fold3 <S3: confusionMatrix>      1.0000000 1.0000000
4    Fold4 <S3: confusionMatrix>      0.8181818 0.8888889
5    Fold5 <S3: confusionMatrix>      1.0000000 0.9090909
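Two further points, offered as a hedged sketch rather than a tested recipe for the question's model. With savePredictions = 'all', rf$pred keeps hold-out predictions for every candidate mtry value, so grouping by Resample alone pools several models per fold; filtering on rf$bestTune$mtry (or using savePredictions = 'final') keeps only the selected one. Also, do() is superseded as of dplyr 1.0.0, and summarise() with list() can build the same list column. The example uses a small hypothetical stand-in for rf$pred (toy_pred and best_mtry are illustrative names) so it runs without training a model; .groups = "drop" assumes dplyr 1.0.0 or later.

```r
library(caret)
library(dplyr)
library(purrr)

lv <- c("versicolor", "virginica")

# Hypothetical stand-in for rf$pred: 2 folds x 2 candidate mtry values x 10 rows,
# with deterministic predictions so the result is reproducible
toy_pred <- expand.grid(row = 1:10, mtry = c(2, 3), Resample = c("Fold1", "Fold2"),
                        stringsAsFactors = FALSE) %>%
  mutate(obs  = factor(rep(lv, length.out = n()), levels = lv),
         pred = factor(rep(lv[c(1, 1, 2)], length.out = n()), levels = lv))

best_mtry <- 2  # with a fitted model this would be rf$bestTune$mtry

fold_metrics <- toy_pred %>%
  filter(mtry == best_mtry) %>%                  # keep one tuning value per fold
  group_by(Resample) %>%
  summarise(cm = list(confusionMatrix(pred, obs)), .groups = "drop") %>%
  mutate(precision = map_dbl(cm, ~ .x$byClass[["Precision"]]),  # per-class statistic
         accuracy  = map_dbl(cm, ~ .x$overall[["Accuracy"]]))   # overall statistic

fold_metrics
```

Here Fold1 gets precision 3/7 and accuracy 0.4, and Fold2 gets 0.5 for both; with the real model only the filter value and the input data frame change.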
