887

Reputation: 619

How to create nested training and testing sets?

I'm working with the ChickWeight data set in R and want to create multiple models, each trained on an individual chick. To do that, I nest the data so that each chick gets its own data frame, stored in a list column.

Here is the start:

library(tidyverse)
library(datasets)
data("ChickWeight")

ChickWeightNest <- ChickWeight %>% 
  group_by(Chick) %>% 
  nest()

From here, fitting a linear regression model to every data frame at once is easy: build the model as a function, then add a new column with mutate() and map(). A more sophisticated model (e.g. xgboost), however, requires first splitting the data into training and testing sets. How can I split all of the nested data frames at once to create training and testing sets, so that I can train multiple models simultaneously?
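A minimal sketch of the linear-regression approach described above, assuming a simple straight-line growth model (weight ~ Time) for each chick. The helper name fit_chick_lm is hypothetical:

```r
library(dplyr)
library(tidyr)
library(purrr)

# Hypothetical helper: fit a straight-line growth model to one chick's data
fit_chick_lm <- function(df) {
  lm(weight ~ Time, data = df)
}

chick_models <- ChickWeight %>%
  group_by(Chick) %>%
  nest() %>%
  # map() applies the helper to every nested tibble, giving one lm per chick
  mutate(model = map(data, fit_chick_lm))
```

Each row of chick_models then holds one chick's data next to its fitted model.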

As a side note, information on training and tuning multiple models seems relatively sparse in my research; any related resources or past Stack Overflow questions would be much appreciated.

Upvotes: 4

Views: 255

Answers (2)

Quinten

Reputation: 41337

Maybe you want something like this: first randomly label each chick's rows as train or test in a new column you can use later, then nest the data per chick:

library(dplyr)
library(tidyr)
library(datasets)
data("ChickWeight")

ChickWeight %>% 
  group_by(Chick) %>% 
  # Label every row independently as "train" or "test"; each chick ends up
  # with a roughly (not exactly) 50/50 split
  mutate(split = sample(c("train", "test"), n(), replace = TRUE)) %>%
  nest() 
#> # A tibble: 50 × 2
#> # Groups:   Chick [50]
#>    Chick data             
#>    <ord> <list>           
#>  1 1     <tibble [12 × 4]>
#>  2 2     <tibble [12 × 4]>
#>  3 3     <tibble [12 × 4]>
#>  4 4     <tibble [12 × 4]>
#>  5 5     <tibble [12 × 4]>
#>  6 6     <tibble [12 × 4]>
#>  7 7     <tibble [12 × 4]>
#>  8 8     <tibble [11 × 4]>
#>  9 9     <tibble [12 × 4]>
#> 10 10    <tibble [12 × 4]>
#> # … with 40 more rows

Created on 2022-06-29 by the reprex package (v2.0.1)
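Once each row carries a split label, one possible next step (a sketch, not part of the original answer) is to pull the training and testing rows out of every nested tibble with map() and filter(). The seed and the column names train and test are assumptions. Because the labels are drawn independently, a chick can occasionally end up with very few training or testing rows.

```r
library(dplyr)
library(tidyr)
library(purrr)

set.seed(2022)  # assumption: fixed seed so the random labels are reproducible

chick_split <- ChickWeight %>%
  group_by(Chick) %>%
  mutate(split = sample(c("train", "test"), n(), replace = TRUE)) %>%
  nest() %>%
  ungroup() %>%
  mutate(
    # Keep only the rows labelled "train" / "test" inside each chick's tibble
    train = map(data, ~ filter(.x, split == "train")),
    test  = map(data, ~ filter(.x, split == "test"))
  )
```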

Upvotes: 1

Roger-123

Reputation: 2520

The key here is realizing that each element of the nested data column is a data frame stored in a list, so you have to use list functions on it, for example lapply() from base R or map() from purrr.
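To illustrate, a small sketch (the object name chick_nest is hypothetical) that applies nrow() to every nested data frame, once with base R's lapply() and once with purrr:

```r
library(dplyr)
library(tidyr)
library(purrr)

chick_nest <- ChickWeight %>%
  group_by(Chick) %>%
  nest() %>%
  ungroup()

# Each element of the list column `data` is a whole tibble, so list functions apply
rows_base  <- lapply(chick_nest$data, nrow)   # base R: returns a list
rows_purrr <- map_int(chick_nest$data, nrow)  # purrr: typed variant returns an integer vector

# Inside mutate(), map_int() keeps each result aligned with its chick's row
chick_nest <- chick_nest %>% mutate(n_obs = map_int(data, nrow))
```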

Here's an example of how that works, using the rsample package to do the split (75% of rows for training):

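ChickWeightNest_example <- ChickWeightNest %>%
  # One rsample split object per chick (75% of rows for training)
  mutate(data_split = purrr::map(data, ~ rsample::initial_split(.x, prop = 0.75))) %>%
  # Extract the training and testing tibbles from each split object
  mutate(
    data_training_only = purrr::map(data_split, ~ rsample::training(.x)),
    data_testing_only  = purrr::map(data_split, ~ rsample::testing(.x))
  )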
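Going one step further, a minimal sketch of training one model per chick on its training rows and scoring it on its testing rows. To stay self-contained it uses a hypothetical base-R helper, split_rows(), in place of rsample::initial_split(), and lm() as a stand-in for a more sophisticated model such as xgboost. Chicks with fewer than four observations (chick 18 has only two) are dropped so every fit has enough rows.

```r
library(dplyr)
library(tidyr)
library(purrr)

set.seed(1)  # assumption: fixed seed for a reproducible split

# Hypothetical helper standing in for rsample::initial_split():
# a random `train` subset of floor(prop * n) rows, the remaining rows as `test`
split_rows <- function(df, prop = 0.75) {
  idx <- sample(nrow(df), floor(prop * nrow(df)))
  list(train = df[idx, ], test = df[-idx, ])
}

chick_fits <- ChickWeight %>%
  group_by(Chick) %>%
  nest() %>%
  ungroup() %>%
  filter(map_int(data, nrow) >= 4) %>%  # too few rows to split and fit otherwise
  mutate(
    split = map(data, split_rows),
    train = map(split, "train"),        # extract the named element from each split
    test  = map(split, "test"),
    model = map(train, ~ lm(weight ~ Time, data = .x)),
    # Root-mean-square error of each chick's model on its own test rows
    rmse  = map2_dbl(model, test, ~ sqrt(mean((.y$weight - predict(.x, newdata = .y))^2)))
  )
```

Replacing the lm() call with another fitting function is the only change needed to try a different model per chick.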

Upvotes: 1
