Doug Fir

Reputation: 21292

Why does each fold in k-fold cross-validation have substantially more than n/k rows?

I'm using rsample to generate folds for cross-validation. My understanding is that cross-validation splits the training data into k folds. However, with the code block below, each fold has far more rows than the n/k I expected:

library(rsample)
library(ggplot2) # provides the diamonds dataset
dim(diamonds) # diamonds has ~54K rows
set.seed(123)
# hold out 10% of the rows as a test set
diamonds_split <- initial_split(diamonds, prop = 0.9)
training_data <- training(diamonds_split)
testing_data <- testing(diamonds_split)

# 5 fold split
train_cv <- vfold_cv(training_data, v = 5)
train_cv # each fold has ~39K, expected roughly (0.9 * 54K) / 5 each fold ~ 9.7K
#  5-fold cross-validation 
# A tibble: 5 x 2
  splits               id   
  <named list>         <chr>
1 <split [38.8K/9.7K]> Fold1
2 <split [38.8K/9.7K]> Fold2
3 <split [38.8K/9.7K]> Fold3
4 <split [38.8K/9.7K]> Fold4
5 <split [38.8K/9.7K]> Fold5

Each fold has 38.8K rows. The diamonds dataset only has 54K rows to begin with. If 0.9 of diamonds is my training set, I expected each fold to have (0.9 * 54K) / 5 ~ 9.7K rows, not 38.8K.

Is my understanding of cross-validation flawed, or have I made an error in my code block?

Upvotes: 1

Views: 77

Answers (1)

Gregor Thomas

Reputation: 146090

Your understanding of k-fold cross-validation is flawed. One fold is left out per iteration. 0.9 * 54k = 48.6k training rows. With 5 folds, you fit on 4/5 of those rows per iteration, and the remaining 1/5 is used as the validation set for that iteration. 48.6k * 4/5 = 38.88k, with the 9.72k balance as the validation set. That is what the printed `<split [38.8K/9.7K]>` shows: the rows used for fitting, then the rows held out.
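To check this on the actual folds, here is a minimal sketch that counts the rows on each side of every split. It assumes rsample's `analysis()` and `assessment()` accessors and that ggplot2 is installed for the `diamonds` data; `fold_sizes` is just an illustrative name, not something from the question.

```r
library(rsample)
library(ggplot2)  # provides the diamonds dataset

set.seed(123)
diamonds_split <- initial_split(diamonds, prop = 0.9)
training_data <- training(diamonds_split)
train_cv <- vfold_cv(training_data, v = 5)

# For each fold, count the rows used for fitting (analysis)
# and the rows held out for validation (assessment).
fold_sizes <- data.frame(
  id = train_cv$id,
  analysis = vapply(train_cv$splits, function(s) nrow(analysis(s)),
                    integer(1), USE.NAMES = FALSE),
  assessment = vapply(train_cv$splits, function(s) nrow(assessment(s)),
                      integer(1), USE.NAMES = FALSE)
)
print(fold_sizes)

# Each training row should be held out exactly once across the 5 folds.
sum(fold_sizes$assessment) == nrow(training_data)
```

The `analysis` column should line up with the ~38.8K in the printed tibble and the `assessment` column with the ~9.7K, and the two should add up to the size of the training set in every fold.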

Upvotes: 3
