89_Simple
89_Simple

Reputation: 3805

group_by operation in dplyr vs data.table for fast implementation

# Simulated data: 10,000 years x 12 months, six random measurement
# columns (x1..x6) plus constant month-window boundary columns.
# Note: the six rnorm() calls are evaluated in column order, so the RNG
# stream (and thus the data under a fixed seed) matches the original.
dat <- data.frame(yearID  = rep(seq_len(10000), each = 12),   # one block of 12 rows per year
                  monthID = rep(seq_len(12), times = 10000),  # 1..12 repeating within each year
                  x1 = rnorm(120000),
                  x2 = rnorm(120000),
                  x3 = rnorm(120000),
                  x4 = rnorm(120000),
                  x5 = rnorm(120000),
                  x6 = rnorm(120000),
                  # window boundaries, identical on every row:
                  p.start = 6,  p.end = 7,   # "p" window covers months 6-7
                  m.start = 8,  m.end = 9,   # "m" window covers months 8-9
                  h.start = 10, h.end = 11)  # "h" window covers months 10-11

I need to do some operations on the above data, which are described below after my current solution:

library(dplyr)

start_time <- Sys.time()

# Per-year, per-index summary for x1..x4 (x5/x6 excluded): reshape long,
# compute the windowed sums per (yearID, index_name), then spread back to
# one wide row per yearID with columns named "<index>_<statistic>".
df1 <- dat %>% 
       tidyr::gather(., index_name, value, x1:x6) %>%
       dplyr::filter(!index_name %in% c('x5','x6')) %>%
       dplyr::group_by(yearID, index_name) %>%
       dplyr::summarise(p.start.val = sum(value[monthID == p.start]),
                        p.val = sum(value[monthID >= p.start & monthID <= p.end]),
                        m.val = sum(value[monthID >= m.start & monthID <= m.end]),
                        h.val = sum(value[monthID >= h.start & monthID <= h.end]),
                        h.end.val = sum(value[monthID == h.end])) %>%
       # Drop the grouping as soon as it is no longer needed; the
       # remaining reshape steps run faster on ungrouped data.
       dplyr::ungroup() %>%
       tidyr::gather(., variable, value, p.start.val:h.end.val) %>%
       dplyr::mutate(new.col.name = paste0(index_name,'_',variable)) %>%
       dplyr::select(-index_name, -variable) %>% 
       tidyr::spread(., new.col.name, value) %>%
       dplyr::mutate(yearRef = 2018)

# Strip the trailing ".val" suffix from the spread column names.
# The pattern must be escaped and anchored: in an unanchored ".val" the
# "." matches ANY character, which could mangle unrelated names.
colnames(df1) <- sub("\\.val$", "", colnames(df1))

# Single-month extracts for x4 and x6: the raw value at each window's
# end month. There is exactly one row per (yearID, monthID), so each
# extract inside summarise() is length 1.
df2 <- dat %>% 
       tidyr::gather(., index_name, value, x1:x6) %>%
       dplyr::filter(index_name %in% c('x4','x6')) %>%
       dplyr::group_by(yearID, index_name) %>%
       dplyr::summarise(p.end.val = value[monthID == p.end],
                        m.end.val = value[monthID == m.end],
                        h.end.val = value[monthID == h.end]) %>%
       # Drop the grouping as soon as it is no longer needed.
       dplyr::ungroup() %>%
       tidyr::gather(., variable, value, p.end.val:h.end.val) %>%
       dplyr::mutate(new.col.name = paste0(index_name,'_',variable)) %>%
       dplyr::select(-index_name, -variable) %>% 
       tidyr::spread(., new.col.name, value) %>%
       dplyr::mutate(yearRef = 2018)

# Same suffix cleanup as for df1: escaped and anchored pattern.
colnames(df2) <- sub("\\.val$", "", colnames(df2))

# Left-join the two wide tables on the year keys. With only two tables a
# single merge() is clearer than Reduce(); fall back to
# Reduce(function(...) merge(..., by = keys, all.x = TRUE), list_of_dfs)
# only when combining three or more tables.
final.dat <- merge(df1, df2, by = c("yearID", "yearRef"), all.x = TRUE)

end_time <- Sys.time()

end_time - start_time

# Time difference of 2.054761 secs

What I want to do is:

My code above works fine but takes quite a lot of time if the size of dat increases, i.e. if the number of years becomes 20000 instead of 10000. I am wondering if someone could help me with a data.table implementation of the above solution, which I hope would make it faster. Thank you.

Upvotes: 1

Views: 213

Answers (1)

r2evans
r2evans

Reputation: 160417

I'll run this on df1 only, since from there the pattern is easily repeatable.

Notes:

  • I'm using magrittr solely to help break out each step in the chain, as each of dplyr's *verbs* is directly translatable. It is not difficult to convert this into a non-`magrittr` pipeline. The benefit of using it (as is also the case with tidyverse pipes) is, in my opinion, readability and therefore maintainability.

The Answer

I'll walk through the steps below.

library(data.table)
library(magrittr)

# data.table translation of the dplyr pipeline that builds df1.
# magrittr's %>% is used only to lay out one data.table step per stage;
# each `.` below is the data.table produced by the previous step.
# The result is printed, not assigned (demonstration only).
as.data.table(dat) %>%
  # tidyr::gather -> data.table::melt; grep() picks the x1..x6 columns.
  melt(., measure.vars = grep("^x[0-9]+", colnames(.)),
       variable.name = "index_name", variable.factor = FALSE) %>%
  # dplyr::filter: drop the x5/x6 measurements.
  .[ !index_name %in% c("x5", "x6"), ] %>%
  # dplyr::group_by + summarise: windowed sums per (yearID, index_name);
  # the j-expressions are copied verbatim from the dplyr summarise().
  .[, .(
    p.start.val = sum(value[monthID == p.start]),
    p.val = sum(value[monthID >= p.start & monthID <= p.end]),
    m.val = sum(value[monthID >= m.start & monthID <= m.end]),
    h.val = sum(value[monthID >= h.start & monthID <= h.end]),
    h.end.val = sum(value[monthID == h.end])
  ), by = .(yearID, index_name) ] %>%
  # second tidyr::gather -> melt over the five summary columns
  # (id.vars 1:2 are yearID and index_name).
  melt(., id.vars = 1:2, variable.factor = FALSE) %>%
  # build the combined "<index>_<statistic>" name, then drop the parts
  # (:= modifies the data.table by reference).
  .[, new.col.name := paste0(index_name, "_", variable) ] %>%
  .[, c("index_name", "variable") := NULL ] %>%
  # tidyr::spread -> data.table::dcast: one wide row per yearID.
  dcast(., yearID ~ new.col.name) %>%
  # constant reference-year column, added by reference.
  .[, yearRef := 2018 ]

Steps:

Notes for the steps:

  • In the walk-through, I add dplyr::arrange_all() and .[order(.),] to the end of each intermediate pipe just so that we have apples-to-apples comparisons.

  • You did not include a random seed for your sample. I used set.seed(42), so to compare your console with what I'm showing, you'll need to set this seed and regenerate dat.

  • Each code block continues from the previous step's code, I shorten all repeated code to ... %>% for brevity to make this answer much less voluminous.

The steps:

  1. tidyr::gather to data.table::melt. There's likely a better way than grep to select ranges of columns in data.table::melt, but while as.data.table(dat)[, -(x1:x6)] works as one might infer, the same column-ranging does not work within melt.

    dat %>% 
      tidyr::gather(., index_name, value, x1:x6) %>%
      arrange_all() %>% head() # just for comparison
    # # A tibble: 6 x 10
    #   yearID monthID p.start p.end m.start m.end h.start h.end index_name  value
    #    <int>   <int>   <dbl> <dbl>   <dbl> <dbl>   <dbl> <dbl> <chr>       <dbl>
    # 1      1       1       6     7       8     9      10    11 x1          1.37 
    # 2      1       1       6     7       8     9      10    11 x2         -0.483
    # 3      1       1       6     7       8     9      10    11 x3         -0.314
    # 4      1       1       6     7       8     9      10    11 x4         -2.23 
    # 5      1       1       6     7       8     9      10    11 x5         -0.717
    # 6      1       1       6     7       8     9      10    11 x6         -1.04 
    as.data.table(dat) %>%
      melt(., measure.vars = grep("^x[0-9]+", colnames(.)),
           variable.name = "index_name", variable.factor = FALSE) %>%
      .[order(.),] %>% head() # just for comparison
    #    yearID monthID p.start p.end m.start m.end h.start h.end index_name      value
    # 1:      1       1       6     7       8     9      10    11         x1  1.3709584
    # 2:      1       1       6     7       8     9      10    11         x2 -0.4831687
    # 3:      1       1       6     7       8     9      10    11         x3 -0.3139498
    # 4:      1       1       6     7       8     9      10    11         x4 -2.2323282
    # 5:      1       1       6     7       8     9      10    11         x5 -0.7167575
    # 6:      1       1       6     7       8     9      10    11         x6 -1.0357630
    
    
  2. add in dplyr::filter and dplyr::summarise (grouped); I literally just copied the new variables' assignments from summarise(...) into .( ... ) block, no change was necessary.

    ... %>%
      dplyr::filter(!index_name %in% c('x5','x6')) %>%
      dplyr::group_by(yearID, index_name) %>%
      dplyr::summarise(p.start.val = sum(value[monthID == p.start]),
                       p.val = sum(value[monthID >= p.start & monthID <= p.end]),
                       m.val = sum(value[monthID >= m.start & monthID <= m.end]),
                       h.val = sum(value[monthID >= h.start & monthID <= h.end]),
                       h.end.val = sum(value[monthID == h.end])) %>%
      arrange_all() %>% head() # just for comparison
    # # A tibble: 6 x 7
    # # Groups:   yearID [2]
    #   yearID index_name p.start.val  p.val   m.val  h.val h.end.val
    #    <int> <chr>            <dbl>  <dbl>   <dbl>  <dbl>     <dbl>
    # 1      1 x1             -0.106   1.41   1.92    1.24      1.30 
    # 2      1 x2              0.573  -0.516 -2.29   -3.54     -0.990
    # 3      1 x3              0.767   0.455  0.461   2.28      2.08 
    # 4      1 x4             -0.0559 -1.11  -0.0975 -0.326    -0.483
    # 5      2 x1             -2.66   -5.10   1.01   -1.95     -0.172
    # 6      2 x2              0.342  -0.546  0.605   1.51      1.25 
    ... %>%
      .[ !index_name %in% c("x5", "x6"), ] %>%
      .[, .(
        p.start.val = sum(value[monthID == p.start]),
        p.val = sum(value[monthID >= p.start & monthID <= p.end]),
        m.val = sum(value[monthID >= m.start & monthID <= m.end]),
        h.val = sum(value[monthID >= h.start & monthID <= h.end]),
        h.end.val = sum(value[monthID == h.end])
      ), by = .(yearID, index_name) ] %>%
      .[order(.),] %>% head(.) # just for comparison
    #    yearID index_name p.start.val      p.val       m.val      h.val  h.end.val
    # 1:      1         x1 -0.10612452  1.4053975  1.92376468  1.2421556  1.3048697
    # 2:      1         x2  0.57306337 -0.5164756 -2.28861552 -3.5367198 -0.9901743
    # 3:      1         x3  0.76706512  0.4546020  0.46096277  2.2819246  2.0842981
    # 4:      1         x4 -0.05589648 -1.1093361 -0.09748514 -0.3260778 -0.4825699
    # 5:      2         x1 -2.65645542 -5.0969223  1.01347475 -1.9532258 -0.1719174
    # 6:      2         x2  0.34227065 -0.5457969  0.60537738  1.5136450  1.2498633
    
  3. tidyr::gather again

    ... %>%
      tidyr::gather(., variable, value, p.start.val:h.end.val) %>%
      arrange_all() %>% head() # just for comparison
    # # A tibble: 6 x 4
    # # Groups:   yearID [1]
    #   yearID index_name variable     value
    #    <int> <chr>      <chr>        <dbl>
    # 1      1 x1         h.end.val    1.30 
    # 2      1 x1         h.val        1.24 
    # 3      1 x1         m.val        1.92 
    # 4      1 x1         p.start.val -0.106
    # 5      1 x1         p.val        1.41 
    # 6      1 x2         h.end.val   -0.990
    ... %>%
      melt(., id.vars = 1:2, variable.factor = FALSE) %>%
      .[order(.),] %>% head(.) # just for comparison
    #    yearID index_name    variable      value
    # 1:      1         x1   h.end.val  1.3048697
    # 2:      1         x1       h.val  1.2421556
    # 3:      1         x1       m.val  1.9237647
    # 4:      1         x1 p.start.val -0.1061245
    # 5:      1         x1       p.val  1.4053975
    # 6:      1         x2   h.end.val -0.9901743
    
  4. tidyr::spread to data.table::dcast

    ... %>%
      dplyr::mutate(new.col.name = paste0(index_name,'_',variable)) %>%
      dplyr::select(-index_name, -variable) %>% 
      tidyr::spread(., new.col.name, value) %>%
      arrange_all() %>% head() # just for comparison
    # # A tibble: 6 x 21
    # # Groups:   yearID [6]
    #   yearID x1_h.end.val x1_h.val x1_m.val x1_p.start.val x1_p.val x2_h.end.val x2_h.val x2_m.val x2_p.start.val x2_p.val x3_h.end.val x3_h.val x3_m.val x3_p.start.val x3_p.val x4_h.end.val x4_h.val x4_m.val x4_p.start.val x4_p.val
    #    <int>        <dbl>    <dbl>    <dbl>          <dbl>    <dbl>        <dbl>    <dbl>    <dbl>          <dbl>    <dbl>        <dbl>    <dbl>    <dbl>          <dbl>    <dbl>        <dbl>    <dbl>    <dbl>          <dbl>    <dbl>
    # 1      1        1.30     1.24     1.92          -0.106    1.41        -0.990   -3.54   -2.29            0.573   -0.516        2.08     2.28     0.461          0.767    0.455      -0.483   -0.326   -0.0975        -0.0559  -1.11  
    # 2      2       -0.172   -1.95     1.01          -2.66    -5.10         1.25     1.51    0.605           0.342   -0.546       -1.38    -0.731    0.443         -0.725   -1.17       -0.623   -1.91     1.49          -0.806   -0.717 
    # 3      3        0.505   -0.104    1.74          -0.640   -0.185        0.570    1.68   -2.24           -0.103   -1.02        -1.36    -2.50    -0.918          1.36     1.26        0.0847  -0.280    0.699          0.114   -0.582 
    # 4      4       -0.811   -0.379   -2.09          -0.361    0.397       -0.782    0.110  -0.0187         -0.641   -0.149       -1.47    -2.45    -1.27           0.418    0.131       0.0582   0.885    0.784          0.998   -0.0115
    # 5      5       -2.99    -2.90     0.956          0.643    0.733        0.165    0.382   1.46            1.48     2.16        -0.451   -0.213   -0.357          0.222    0.686      -0.949   -0.156    1.23           1.35     0.908 
    # 6      6       -1.04    -0.322    1.96           1.30     1.64         0.838   -0.406   1.86            0.863    2.11         0.479    2.37    -1.13          -1.22    -1.63       -0.970    0.0391  -1.08           0.683   -1.24  
    ... %>%
      .[, new.col.name := paste0(index_name, "_", variable) ] %>%
      .[, c("index_name", "variable") := NULL ] %>%
      dcast(., yearID ~ new.col.name) %>%
      .[order(.),] %>% head(.) # just for comparison
    #    yearID x1_h.end.val   x1_h.val   x1_m.val x1_p.start.val   x1_p.val x2_h.end.val   x2_h.val    x2_m.val x2_p.start.val   x2_p.val x3_h.end.val   x3_h.val   x3_m.val x3_p.start.val   x3_p.val x4_h.end.val    x4_h.val    x4_m.val x4_p.start.val    x4_p.val
    # 1:      1    1.3048697  1.2421556  1.9237647     -0.1061245  1.4053975   -0.9901743 -3.5367198 -2.28861552      0.5730634 -0.5164756    2.0842981  2.2819246  0.4609628      0.7670651  0.4546020  -0.48256993 -0.32607779 -0.09748514    -0.05589648 -1.10933614
    # 2:      2   -0.1719174 -1.9532258  1.0134748     -2.6564554 -5.0969223    1.2498633  1.5136450  0.60537738      0.3422707 -0.5457969   -1.3790815 -0.7305400  0.4429124     -0.7249950 -1.1681343  -0.62293711 -1.90725766  1.48980773    -0.80634526 -0.71692479
    # 3:      3    0.5049551 -0.1039713  1.7399409     -0.6399949 -0.1845448    0.5697303  1.6768675 -2.24285021     -0.1029872 -1.0245616   -1.3608773 -2.5029906 -0.9178704      1.3641160  1.2619892   0.08468983 -0.27967757  0.69899862     0.11429665 -0.58216791
    # 4:      4   -0.8113932 -0.3785752 -2.0949859     -0.3610573  0.3971059   -0.7823128  0.1098614 -0.01867344     -0.6414615 -0.1488759   -1.4653210 -2.4476336 -1.2718183      0.4179297  0.1311655   0.05823201  0.88484095  0.78382293     0.99795594 -0.01147192
    # 5:      5   -2.9930901 -2.9032572  0.9558396      0.6428993  0.7326600    0.1645109  0.3819658  1.45532687      1.4820236  2.1608213   -0.4513016 -0.2129462 -0.3572757      0.2221201  0.6855960  -0.94859958 -0.15646638  1.23051588     1.34645936  0.90755241
    # 6:      6   -1.0431189 -0.3222408  1.9592347      1.3025426  1.6383908    0.8379162 -0.4059827  1.86142674      0.8626753  2.1076609    0.4792767  2.3683451 -1.1252801     -1.2213407 -1.6339743  -0.96979464  0.03912882 -1.08199221     0.68254513 -1.23950872
    
  5. finish it up

    df1a <- df1 %>% arrange_all()
    head(df1a)
    # # A tibble: 6 x 22
    # # Groups:   yearID [6]
    #   yearID x1_h.end   x1_h   x1_m x1_p.start   x1_p x2_h.end   x2_h    x2_m x2_p.start   x2_p x3_h.end   x3_h   x3_m x3_p.start   x3_p x4_h.end    x4_h    x4_m x4_p.start    x4_p yearRef
    #    <int>    <dbl>  <dbl>  <dbl>      <dbl>  <dbl>    <dbl>  <dbl>   <dbl>      <dbl>  <dbl>    <dbl>  <dbl>  <dbl>      <dbl>  <dbl>    <dbl>   <dbl>   <dbl>      <dbl>   <dbl>   <dbl>
    # 1      1    1.30   1.24   1.92      -0.106  1.41    -0.990 -3.54  -2.29        0.573 -0.516    2.08   2.28   0.461      0.767  0.455  -0.483  -0.326  -0.0975    -0.0559 -1.11      2018
    # 2      2   -0.172 -1.95   1.01      -2.66  -5.10     1.25   1.51   0.605       0.342 -0.546   -1.38  -0.731  0.443     -0.725 -1.17   -0.623  -1.91    1.49      -0.806  -0.717     2018
    # 3      3    0.505 -0.104  1.74      -0.640 -0.185    0.570  1.68  -2.24       -0.103 -1.02    -1.36  -2.50  -0.918      1.36   1.26    0.0847 -0.280   0.699      0.114  -0.582     2018
    # 4      4   -0.811 -0.379 -2.09      -0.361  0.397   -0.782  0.110 -0.0187     -0.641 -0.149   -1.47  -2.45  -1.27       0.418  0.131   0.0582  0.885   0.784      0.998  -0.0115    2018
    # 5      5   -2.99  -2.90   0.956      0.643  0.733    0.165  0.382  1.46        1.48   2.16    -0.451 -0.213 -0.357      0.222  0.686  -0.949  -0.156   1.23       1.35    0.908     2018
    # 6      6   -1.04  -0.322  1.96       1.30   1.64     0.838 -0.406  1.86        0.863  2.11     0.479  2.37  -1.13      -1.22  -1.63   -0.970   0.0391 -1.08       0.683  -1.24      2018
    df1b <- ... %>%
      .[, yearRef := 2018 ] %>%
      .[order(.),]
    head(df1b)
    #    yearID x1_h.end.val   x1_h.val   x1_m.val x1_p.start.val   x1_p.val x2_h.end.val   x2_h.val    x2_m.val x2_p.start.val   x2_p.val x3_h.end.val   x3_h.val   x3_m.val x3_p.start.val   x3_p.val x4_h.end.val    x4_h.val    x4_m.val x4_p.start.val    x4_p.val yearRef
    # 1:      1    1.3048697  1.2421556  1.9237647     -0.1061245  1.4053975   -0.9901743 -3.5367198 -2.28861552      0.5730634 -0.5164756    2.0842981  2.2819246  0.4609628      0.7670651  0.4546020  -0.48256993 -0.32607779 -0.09748514    -0.05589648 -1.10933614    2018
    # 2:      2   -0.1719174 -1.9532258  1.0134748     -2.6564554 -5.0969223    1.2498633  1.5136450  0.60537738      0.3422707 -0.5457969   -1.3790815 -0.7305400  0.4429124     -0.7249950 -1.1681343  -0.62293711 -1.90725766  1.48980773    -0.80634526 -0.71692479    2018
    # 3:      3    0.5049551 -0.1039713  1.7399409     -0.6399949 -0.1845448    0.5697303  1.6768675 -2.24285021     -0.1029872 -1.0245616   -1.3608773 -2.5029906 -0.9178704      1.3641160  1.2619892   0.08468983 -0.27967757  0.69899862     0.11429665 -0.58216791    2018
    # 4:      4   -0.8113932 -0.3785752 -2.0949859     -0.3610573  0.3971059   -0.7823128  0.1098614 -0.01867344     -0.6414615 -0.1488759   -1.4653210 -2.4476336 -1.2718183      0.4179297  0.1311655   0.05823201  0.88484095  0.78382293     0.99795594 -0.01147192    2018
    # 5:      5   -2.9930901 -2.9032572  0.9558396      0.6428993  0.7326600    0.1645109  0.3819658  1.45532687      1.4820236  2.1608213   -0.4513016 -0.2129462 -0.3572757      0.2221201  0.6855960  -0.94859958 -0.15646638  1.23051588     1.34645936  0.90755241    2018
    # 6:      6   -1.0431189 -0.3222408  1.9592347      1.3025426  1.6383908    0.8379162 -0.4059827  1.86142674      0.8626753  2.1076609    0.4792767  2.3683451 -1.1252801     -1.2213407 -1.6339743  -0.96979464  0.03912882 -1.08199221     0.68254513 -1.23950872    2018
    

They do match:

identical(as.data.frame(df1a), as.data.frame(df1b))
# [1] TRUE

The speed-ups are not gargantuan, but they do appear to be significant. One way you can speed up your own code (still dplyr) is to remove the grouping as soon as you don't need it. If I add ungroup() immediately after summarise(...), I see a small improvement.

microbenchmark::microbenchmark(
  dplyr = { ... },
  dplyr_ungrp = { ... },
  data.table = { ... },
  times = 10
)
# Unit: milliseconds
#         expr      min        lq      mean    median        uq       max neval
#        dplyr 988.8311 1021.4725 1048.5462 1045.6885 1066.2733 1135.6032    10
#  dplyr_ungrp 909.3643  913.9301  952.6282  937.6540  998.2802 1041.2144    10
#   data.table 457.4500  465.1788  478.1471  474.2388  478.9840  531.1449    10

Upvotes: 2

Related Questions