Matthew Loh

Reputation: 147

Coding Multiple Models in a Function in R Tidyverse

I am trying to fit several machine learning models with a few formulas and store them in a tibble as list-column objects.

I have tried to modify the code from the book 'R for Data Science' (Chapter 25: Many Models), but it only gives me the last output. Please refer to the code below for more details. We are using the gapminder dataset from the gapminder package as an example.

lab_formula <- as.formula("pop ~ lifeExp ")

temp_formula <- as.formula("gdppercap ~ year")

formula_list <- list(lab_formula,temp_formula)
library(gapminder)

by_country <- gapminder %>% 
  dplyr :: group_by(country, continent) %>% 
  nest()

country_model <- function(df) {
for (i in formula_list) {
  lm(formula=formula[i], data = df)
  randomForest(formula=formula[i], data = df)
  gbm(formula=formula[i], data = df, n.minobsinnode = 2)
}
}

by_country <- by_country %>% 
  mutate(model = map(data, country_model))

by_country
# A tibble: 142 x 4
   country     continent data              model    
   <fct>       <fct>     <list>            <list>   
 1 Afghanistan Asia      <tibble [12 x 4]> <S3: gbm>
 2 Albania     Europe    <tibble [12 x 4]> <S3: gbm>
 3 Algeria     Africa    <tibble [12 x 4]> <S3: gbm>
 4 Angola      Africa    <tibble [12 x 4]> <S3: gbm>
 5 Argentina   Americas  <tibble [12 x 4]> <S3: gbm>
 6 Australia   Oceania   <tibble [12 x 4]> <S3: gbm>
 7 Austria     Europe    <tibble [12 x 4]> <S3: gbm>
 8 Bahrain     Asia      <tibble [12 x 4]> <S3: gbm>
 9 Bangladesh  Asia      <tibble [12 x 4]> <S3: gbm>
10 Belgium     Europe    <tibble [12 x 4]> <S3: gbm>
# ... with 132 more rows

There is no error, but it does not achieve my objective of training the three machine learning models (LM, RF, GBM) with the different variables.

Upvotes: 0

Views: 809

Answers (1)

Ronak Shah

Reputation: 388962

You need to think about how you want to store your results. Here is one way to do that. First, create a list of the formulas you want to apply (note that the gapminder column is named gdpPercap, not gdppercap):

library(randomForest)
library(gbm)
library(tidyverse)

lab_formula <- as.formula("pop ~ lifeExp ")
temp_formula <- as.formula("gdpPercap ~ year")
formula_list <- list(lab_formula,temp_formula)

Create a function that fits all three models to one formula at a time and returns them as a list.

country_model <- function(df, formula_list, index) {
    list(lm(formula = formula_list[[index]] , data = df), 
         randomForest(formula=formula_list[[index]], data = df),
         gbm(formula=formula_list[[index]], data = df, n.minobsinnode = 2))
}
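
Before mapping the function over every country, it can help to try it on a single data frame. The following is only a sketch: the `afghanistan` and `fits` objects are names introduced here for illustration, and it assumes the randomForest, gbm and gapminder packages are installed.

```r
library(randomForest)
library(gbm)
library(gapminder)

formula_list <- list(pop ~ lifeExp, gdpPercap ~ year)

# Same function as in the answer: three models for one formula
country_model <- function(df, formula_list, index) {
  list(lm(formula = formula_list[[index]], data = df),
       randomForest(formula = formula_list[[index]], data = df),
       gbm(formula = formula_list[[index]], data = df, n.minobsinnode = 2))
}

# Illustrative input: one country's rows with only the numeric columns
afghanistan <- as.data.frame(
  subset(gapminder, country == "Afghanistan",
         select = c(year, lifeExp, pop, gdpPercap))
)

# Fit the three models with the first formula
fits <- country_model(afghanistan, formula_list, 1)

# One entry per model, in the order lm, randomForest, gbm
length(fits)
sapply(fits, function(m) class(m)[1])
```

Because the list is unnamed, the position of each model is the only way to tell them apart later.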

Now apply it to each nested data frame, passing formula_list and the index of the formula you want to use:

df1 <- by_country %>% 
  mutate(model1 = map(data, ~country_model(., formula_list, 1)), 
         model2 = map(data, ~country_model(., formula_list, 2)))
df1

# A tibble: 142 x 5
#   country     continent data              model1     model2    
#   <fct>       <fct>     <list>            <list>     <list>    
# 1 Afghanistan Asia      <tibble [12 × 4]> <list [3]> <list [3]>
# 2 Albania     Europe    <tibble [12 × 4]> <list [3]> <list [3]>
# 3 Algeria     Africa    <tibble [12 × 4]> <list [3]> <list [3]>
# 4 Angola      Africa    <tibble [12 × 4]> <list [3]> <list [3]>
# 5 Argentina   Americas  <tibble [12 × 4]> <list [3]> <list [3]>
# 6 Australia   Oceania   <tibble [12 × 4]> <list [3]> <list [3]>
# 7 Austria     Europe    <tibble [12 × 4]> <list [3]> <list [3]>
# 8 Bahrain     Asia      <tibble [12 × 4]> <list [3]> <list [3]>
# 9 Bangladesh  Asia      <tibble [12 × 4]> <list [3]> <list [3]>
#10 Belgium     Europe    <tibble [12 × 4]> <list [3]> <list [3]>
# … with 132 more rows

Now every row in model1 holds a list of three models fitted with formula_list[[1]], and every row in model2 holds the models fitted with formula_list[[2]].
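
If the number of formulas may grow, one column per formula gets unwieldy. A possible variation, sketched here with only lm to keep it short, is to name the formulas and store one named list of fits per country. The helper `fit_all`, the list names `pop_lifeExp` and `gdp_year`, and the column `fits` are hypothetical names chosen for this example.

```r
library(tidyverse)
library(gapminder)

# Named formulas, so each fitted model can be looked up by name
formula_list <- list(
  pop_lifeExp = pop ~ lifeExp,
  gdp_year    = gdpPercap ~ year
)

by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest()

# Hypothetical helper: fit one lm per formula; map() keeps the list names
fit_all <- function(df, formulas) {
  map(formulas, ~ lm(.x, data = df))
}

models <- by_country %>%
  ungroup() %>%
  mutate(fits = map(data, ~ fit_all(.x, formula_list)))

# Pull out a single model by row and by formula name
afghan_fit <- models$fits[[1]]$gdp_year
coef(afghan_fit)
```

The same pattern should extend to randomForest and gbm by returning a nested named list from the helper, at the cost of a deeper structure to index into.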


To use these models for prediction, the gbm model needs to be treated differently, since its predict method takes an n.trees argument. Because the function returns the models in a fixed order, we know gbm is the third element of the list and can single it out by index.

df1 %>%
   mutate(pred= map2(data,model1, function(x, y) 
     map(seq_along(y), function(i) 
        if (i == 3) predict(y[[i]], n.trees = y[[i]]$n.trees)
        else as.numeric(predict(y[[i]])))))

Upvotes: 1
