Divi

Reputation: 1634

Fitting different models to each subset of data in R

I have a large dataset with multiple classes. My aim is to fit a model to each class, then predict the results and visualize them for each class in a facet.

For a reproducible example, I have created something basic using mtcars. This works well for a single regression model per class:

library(data.table)
library(ggplot2)

mtcars = data.table(mtcars)
model = mtcars[, list(fit = list(lm(mpg~disp+hp+wt))), keyby = cyl]
setkey(mtcars, cyl)
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI]
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl)

However, I would like to try something like the following, which does not work yet. This attempt uses a list of formulas, but I would also like to send different models (some GLMs, a few trees) to each subset of data.

mtcars = data.table(mtcars)
factors = list(c("disp","wt"), c("disp"), c("hp"))
form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+"))))
model = mtcars[, list(fit = list(lm(form))), keyby = cyl]
setkey(mtcars, cyl)
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI]
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl)

Upvotes: 5

Views: 1080

Answers (3)

Jonathan Carroll

Reputation: 3947

I'm actually doing almost exactly this at the moment, so perfect timing. This is going to be a "tidyverse"-heavy answer, but I really like the way it works.

purrr has some very handy map functions that make this incredibly smooth when combined with list-columns in tibble. Using your definitions (I'm not trying to optimise them):

library(data.table)
mtcars = data.table(mtcars)
factors = list(c("disp","wt"), c("disp"), c("hp"))
form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+"))))

This provides a list of formulas, which can be passed to purrr::invoke_map. It calls a list of functions (in your case just lm, but I suspect this extends to other model functions too) with a list of arguments (your formulas), plus any optional arguments shared by every call (in your example, data = mtcars). Using tibble, the results are stored as a neat data.frame-like list-column; otherwise they are returned as a plain list of lm objects:

library(tibble)
library(purrr) 
models <- tibble(fit = invoke_map(lm, form, data = mtcars))
models
#> # A tibble: 3 x 1
#>          fit
#>       <list>
#>   1 <S3: lm>
#>   2 <S3: lm>
#>   3 <S3: lm>
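A minimal sketch of the same idea with plain purrr::map2(), which also touches on the "different model types" part of the question. It pairs a hypothetical list of fitting functions (fitters, here lm and glm) with the formulas and fits each pair to the full mtcars data. Swapping in other functions that accept formula and data arguments is an assumption I have not tested for every model type. invoke_map() has been deprecated in recent purrr releases, so this form may age better.

```r
library(purrr)

factors <- list(c("disp", "wt"), c("disp"), c("hp"))
form <- lapply(factors, function(x) as.formula(paste("mpg ~", paste(x, collapse = " + "))))

# hypothetical: one fitting function per formula (any function taking formula + data)
fitters <- list(lm, glm, lm)

# call each fitting function on its matching formula, all using the full mtcars data
fits <- map2(fitters, form, function(fit_fun, f) fit_fun(f, data = mtcars))

map(fits, class)
```

Because each element of fitters is just a function, the list can mix model types freely; the only requirement in this sketch is a shared formula/data calling convention.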

The super-useful part comes when you want to do something to all of those elements, say, extract the fitted coefficients:

map(models$fit, coefficients)
#> [[1]]
#> (Intercept)        disp          wt 
#> 34.96055404 -0.01772474 -3.35082533 
#> 
#> [[2]]
#> (Intercept)        disp 
#> 29.59985476 -0.04121512 
#> 
#> [[3]]
#> (Intercept)          hp 
#> 30.09886054 -0.06822828 
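If you would rather have the coefficients in one data frame than in a list, broom::tidy() can be applied to every fit and the results stacked. A sketch, assuming dplyr is available for bind_rows(); the model column is simply the position of each fit in the list.

```r
library(broom)
library(dplyr)

form <- list(mpg ~ disp + wt, mpg ~ disp, mpg ~ hp)
fits <- lapply(form, lm, data = mtcars)

# one row per (model, term); .id records which list element each row came from
coef_table <- bind_rows(lapply(fits, tidy), .id = "model")
coef_table
```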

or re-examine the formulas used:

map(models$fit, formula)
#> [[1]]
#> mpg ~ disp + wt
#> <environment: 0x0000000017ee73a8>
#> 
#> [[2]]
#> mpg ~ disp
#> <environment: 0x0000000018392c58>
#> 
#> [[3]]
#> mpg ~ hp
#> <environment: 0x0000000018471d18>

Furthermore, if you want to add predictions from the models, this is easily achieved using broom::augment:

library(broom)
library(dplyr)  # for mutate() and %>%
models_with_predicts <- models %>% mutate(predict = map(fit, augment))
models_with_predicts
#> # A tibble: 3 x 2
#>          fit                predict
#>       <list>                 <list>
#>   1 <S3: lm> <data.frame [32 x 10]>
#>   2 <S3: lm>  <data.frame [32 x 9]>
#>   3 <S3: lm>  <data.frame [32 x 9]>

You can get back to the data level (with predictions) by unnest()ing, but this stacks the augmented data from all three fits into one table (add a grouping column first if you need to keep the fits separate):

library(tidyr)
unnest(models_with_predicts, predict)

#> # A tibble: 96 x 11
#>      mpg  disp    wt  .fitted   .se.fit     .resid       .hat   .sigma     .cooksd .std.resid    hp
#>    <dbl> <dbl> <dbl>    <dbl>     <dbl>      <dbl>      <dbl>    <dbl>       <dbl>      <dbl> <dbl>
#>  1  21.0 160.0 2.620 23.34543 0.6075520 -2.3454326 0.04339369 2.933379 0.010222201 -0.8222164    NA
#>  2  21.0 160.0 2.875 22.49097 0.6221836 -1.4909721 0.04550894 2.954135 0.004351414 -0.5232550    NA
#>  3  22.8 108.0 2.320 25.27237 0.7326015 -2.4723669 0.06309504 2.928665 0.017217431 -0.8757799    NA
#>  4  21.4 258.0 3.215 19.61467 0.5743205  1.7853334 0.03877647 2.948162 0.005241995  0.6243627    NA
#>  5  18.7 360.0 3.440 17.05281 1.0943208  1.6471930 0.14078260 2.949120 0.020275438  0.6092882    NA
#>  6  18.1 225.0 3.460 19.37863 0.6122393 -1.2786309 0.04406584 2.957872 0.003089406 -0.4483953    NA
#>  7  14.3 360.0 3.570 16.61720 0.9897465 -2.3171997 0.11516157 2.931444 0.030948880 -0.8446199    NA
#>  8  24.4 146.7 3.190 21.67120 0.9053245  2.7287988 0.09635365 2.918183 0.034431234  0.9842424    NA
#>  9  22.8 140.8 3.150 21.90981 0.9165259  0.8901898 0.09875274 2.962885 0.003775416  0.3215070    NA
#> 10  19.2 167.6 3.440 20.46305 0.9678618 -1.2630477 0.11012510 2.957375 0.008693734 -0.4590766    NA
#> # ... with 86 more rows
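One way to keep the fits apart after unnesting is to add an identifier column before calling unnest(). A self-contained sketch, assuming current tidyr/dplyr syntax; the labels in the model column are made up for illustration, and map() with an extra data argument stands in for invoke_map() here.

```r
library(tibble)
library(purrr)
library(dplyr)
library(tidyr)
library(broom)

form <- list(mpg ~ disp + wt, mpg ~ disp, mpg ~ hp)

models <- tibble(
  model = c("disp + wt", "disp", "hp"),  # hypothetical labels, one per formula
  fit   = map(form, lm, data = mtcars)
)

preds <- models %>%
  mutate(aug = map(fit, augment)) %>%  # augmented data (with .fitted) per model
  select(model, aug) %>%               # drop the lm objects before unnesting
  unnest(aug)                          # 32 rows per model, labelled by model

count(preds, model)
```

With the model column in place, the result can go straight into ggplot with facet_wrap(~model).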

Upvotes: 1

Weihuang Wong

Reputation: 13118

Here's an approach where we set up a predict() call for each model in an unevaluated list, evaluate it within the data.table object for each cyl group, reshape the output with tidyr::gather, and pass it into ggplot:

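library(data.table)

# form is the list of three formulas defined in the question
models = quote(list(
      predict(lm(form[[1]], .SD)),
      predict(lm(form[[2]], .SD)),
      predict(lm(form[[3]], .SD))))

d <- data.table(mtcars)  # a copy, so := does not also modify mtcars in place
# fit all three models within each cyl group and store the fitted values
d[, c("est1", "est2", "est3") := eval(models), by = cyl]
# reshape to long format: one row per observation per model
d <- tidyr::gather(d, key = model, value = pred, est1:est3)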

library(ggplot2)
ggplot(d, aes(x = mpg, y = pred)) + geom_line() + facet_grid(cyl ~ model)

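Output: a faceted plot of pred against mpg, with one row of panels per cyl and one column per model (est1, est2, est3); image not reproduced.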
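The quoted list has to be edited by hand whenever a formula is added. A sketch of a variant that builds the estimate columns from however many formulas form contains; the est1, est2, ... naming follows the answer above, and the rest (est_cols, the lapply inside j) is my own assumption rather than part of the original approach.

```r
library(data.table)

factors <- list(c("disp", "wt"), c("disp"), c("hp"))
form <- lapply(factors, function(x) as.formula(paste("mpg ~", paste(x, collapse = " + "))))

d <- data.table(mtcars)
est_cols <- paste0("est", seq_along(form))  # est1, est2, est3, ...

# within each cyl group: fit every formula to .SD, then return the fitted values
d[, (est_cols) := {
  fits <- lapply(form, lm, data = .SD)
  lapply(fits, predict)
}, by = cyl]

head(d[, c("cyl", "mpg", est_cols), with = FALSE])
```

The tidyr::gather and ggplot steps from the answer then apply unchanged, using est_cols in place of est1:est3.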

Upvotes: 4

Arun

Reputation: 118799

lm() also accepts the formula as a character string, so I'd simply create form as:

form = lapply(factors, function(x) paste("mpg~", paste(x, collapse="+")))

You will also need to supply the correct data for each group, using the built-in special symbol .SD:

model = mtcars[, list(fit=lapply(form, lm, data=.SD)), keyby=cyl]

For each cyl, lapply() loops through form and passes each formula as the first argument to lm(), along with data = .SD. Here .SD stands for Subset of Data: it is itself a data.table holding the current group's rows. You can read more about it in the data.table vignettes.

If you also want to keep the formula in the result, then:

chform = unlist(form)
model = mtcars[, list(form=chform, fit=lapply(form, lm, data=.SD)), keyby = cyl]
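To close the loop on the original goal (predictions plotted per class), the table of fits can be turned back into fitted values with one more grouped call. A sketch, assuming ggplot2 for the plot; it relies on lm() keeping its model frame in fit$model (the default, model = TRUE), and dt, preds and p are names introduced here for illustration.

```r
library(data.table)
library(ggplot2)

dt <- data.table(mtcars)
factors <- list(c("disp", "wt"), c("disp"), c("hp"))
form <- lapply(factors, function(x) paste("mpg ~", paste(x, collapse = " + ")))
chform <- unlist(form)

# one row per (cyl, formula), with the lm fit stored in a list column
model <- dt[, list(form = chform, fit = lapply(form, lm, data = .SD)), keyby = cyl]

# each (cyl, form) group holds exactly one fit: pull out observed and fitted mpg
preds <- model[, list(mpg = fit[[1]]$model$mpg, pred = fitted(fit[[1]])),
               by = list(cyl, form)]

p <- ggplot(preds, aes(x = mpg, y = pred)) + geom_point() + facet_grid(cyl ~ form)
p
```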

HTH

PS: Please read this post if you plan to use update() within [...] using data.tables.

Upvotes: 3
