Ismael

Reputation: 3

Fastest way to compute this triple summation in R

My goal is to compute the following triple summation:

$V = \frac{1}{n_1 n_2 n_3} \sum_{i=1}^{n_1}\sum_{j=1}^{n_2}\sum_{k=1}^{n_3} I(Y_{1i}, Y_{2j}, Y_{3k})$

where I(Y1,Y2,Y3) is defined as:

I(Y1, Y2, Y3) = 1    if Y1 < Y2 < Y3
I(Y1, Y2, Y3) = 1/2  if Y1 = Y2 < Y3
I(Y1, Y2, Y3) = 1/6  if Y1 = Y2 = Y3
I(Y1, Y2, Y3) = 0    otherwise
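The definition translates directly into a vectorized R function that scores whole vectors of triples at once, which is the form the answers below rely on. A minimal sketch; the name `ind_vec` is hypothetical and does not appear in the original code:

```r
# Vectorized transcription of I(Y1, Y2, Y3); ind_vec is a hypothetical name.
# y1, y2, y3 are numeric vectors of equal length, one triple per position.
ind_vec <- function(y1, y2, y3) {
  ifelse(y1 < y2 & y2 < y3, 1,                       # Y1 < Y2 < Y3
         ifelse(y1 == y2 & y2 < y3, 1/2,             # Y1 = Y2 < Y3
                ifelse(y1 == y2 & y2 == y3, 1/6,     # Y1 = Y2 = Y3
                       0)))                          # otherwise
}

# One triple per case of the definition
ind_vec(c(1, 1, 1, 2), c(2, 1, 1, 1), c(3, 2, 1, 3))  # 1, 1/2, 1/6, 0
```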

The problem is that this approach is very expensive to compute. I suspect that is because I use expand.grid() to build a matrix of all combinations and then compute Result row by row.

Does anyone have a more efficient way to do this? My current R implementation is:

set.seed(123)

nclasses = 3

ind <- function(Y){
  res = 0
  if (Y[1] < Y[2] & Y[2] < Y[3]) {
    res = 1
  } else if (Y[1] == Y[2] & Y[2] < Y[3]) {
    res = 1/2
  } else if (Y[1] == Y[2] & Y[2] == Y[3]) {
    res = 1/6
  }
  return(res)
}

N_obs = 300
c0 <- rnorm(N_obs)
l0 = length(c0)

c1 <- rnorm(N_obs)
l1 = length(c1)

c2 <- rnorm(N_obs)
l2 = length(c2)

mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
dim(mat)

Result <- (1/(l0*l1*l2))*sum(apply(mat, 1, ind))
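A rough size estimate makes the cost concrete. This back-of-the-envelope sketch assumes 8-byte doubles and ignores R's per-object overhead and the temporary copies made by t() and unlist():

```r
# Back-of-the-envelope size of the full grid (assumes 8-byte doubles;
# ignores object overhead and temporary copies)
N_obs  <- 300
n_rows <- N_obs^3          # one row per (i, j, k) combination
bytes  <- n_rows * 3 * 8   # three numeric columns

n_rows       # 2.7e+07 rows, so apply() calls ind() 27 million times
bytes / 1e6  # 648 (MB) for the final matrix alone
```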

Upvotes: 0

Views: 327

Answers (2)

Cole

Reputation: 11255

tl;dr: data.table with non-equi joins solves the whole problem in less time than tidyr takes just to generate the full data. Still, the tidyr / dplyr solution reads more cleanly.

data.table(c0)[
  data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = TRUE
][
  data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = TRUE
][
  c0 <= c1 & c1 <= c2,   # also drops any NA rows left by unmatched join keys
  sum(ifelse(c0 < c1 & c1 < c2, 1,
             ifelse(c0 == c1 & c1 < c2, 1/2,
                    ifelse(c0 == c1 & c1 == c2, 1/6, 0))))
] / (length(c0) * length(c1) * length(c2))

There are two places to gain speed: how the data is generated, and the calculation itself.

Generating Data

In base R, the fastest approach is the simplest one. Instead of transposing and unlisting, you can use as.matrix() for clarity and a slight speed-up. Or you can keep the expand.grid() result as a data.frame, which is analogous to the tidyr solution that creates a tibble.

The data.table equivalent is CJ(c0, c1, c2), which in the benchmark below is roughly 6 times faster (by median) than the fastest base R or tidyr equivalent.

#Creating dataset
Unit: milliseconds
                expr     min      lq    mean  median      uq     max neval
            original 1185.10 1239.37 1478.46 1503.68 1690.47 1899.37    10
           as.matrix 1023.49 1041.72 1213.17 1198.24 1360.51 1420.78    10
         expand.grid  764.43  840.11 1030.13 1030.79 1146.82 1354.06    10
      tidyr_complete 2811.00 2948.86 3118.33 3158.59 3290.21 3364.52    10
      tidyr_crossing 1154.94 1171.01 1311.71 1233.40 1545.30 1609.86    10
       data.table_CJ  154.71  155.30  175.65  162.54  174.96  291.14    10

Another approach is to use non-equi joins or to pre-filter the data. If c0 > c1 or c1 > c2, the indicator is 0, so those combinations add nothing to the summation. With a non-equi join those rows are never created in the first place, so less has to be stored in memory.

While these approaches are slower than plain data.table::CJ(), they set up the triple summation better, because most of the zero-valued rows are already gone.

# 'data.table_CJ_filter' = CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
#'tidyr_cross_filter' =  crossing(c0, c1) %>% filter(c0 <= c1) %>% crossing(c2) %>% filter(c1 <= c2)

#Creating dataset with future calcs in mind
Unit: milliseconds
                 expr    min     lq   mean median      uq     max neval
  data.table_non_equi 358.41 360.35 373.95 374.57  383.62  400.42    10
 data.table_CJ_filter 515.50 517.99 605.06 527.63  661.54  856.43    10
   tidyr_cross_filter 776.91 783.35 980.19 928.25 1178.47 1287.91    10
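That pre-filtering is safe can be checked on a small grid in base R: every row removed by c0 <= c1 & c1 <= c2 has I = 0, so the sum is unchanged. This sketch assumes a reduced n = 30 and rounds the draws to one decimal so that ties (the 1/2 and 1/6 cases) actually occur; the names `x0`, `x1`, `x2` and `ind_vec` are illustrative:

```r
# Small-scale check that the pre-filter c0 <= c1 <= c2 does not change the sum
set.seed(123)
n  <- 30
# Rounded draws so that ties (the 1/2 and 1/6 cases) actually occur
x0 <- round(rnorm(n), 1)
x1 <- round(rnorm(n), 1)
x2 <- round(rnorm(n), 1)

# Vectorized indicator with the same branches as I(Y1, Y2, Y3)
ind_vec <- function(y1, y2, y3) {
  ifelse(y1 < y2 & y2 < y3, 1,
         ifelse(y1 == y2 & y2 < y3, 1/2,
                ifelse(y1 == y2 & y2 == y3, 1/6, 0)))
}

grid <- expand.grid(c0 = x0, c1 = x1, c2 = x2)   # all n^3 combinations
keep <- with(grid, c0 <= c1 & c1 <= c2)          # the pre-filter

full_sum     <- with(grid, sum(ind_vec(c0, c1, c2)))
filtered_sum <- with(grid[keep, ], sum(ind_vec(c0, c1, c2)))

all.equal(full_sum, filtered_sum)  # TRUE: every dropped row had I = 0
mean(keep)                         # share of rows that survive the filter
```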

Calculating the summation

@Jon Spring's solution is great. case_when() and ifelse() are vectorized, whereas your original if ... else statements are evaluated one row at a time through apply(). I translated Jon's answer to base R; it's faster than your original solution but still takes about 50% longer than dplyr.

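One note: if you did the non-equi join (or any pre-filter), you can simplify the case_when, because the filtering already removed every row where c0 > c1 or c1 > c2. The remaining rows get 1, 1/2, or 1/6, except for ties of the form c0 < c1 == c2, which still get 0. The pre-filtered solutions are roughly 10x to 23x faster (comparing medians) than the ones on data that had not been pre-filtered; all four timings below exclude the time spent generating the data.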

Unit: milliseconds
             expr     min      lq    mean  median      uq     max neval
             base 5666.93 6003.87 6303.27 6214.58 6416.42 7423.30    10
            dplyr 3633.48 3963.47 4160.68 4178.15 4395.96 4530.15    10
       data.table  236.83  262.10  305.19  268.47  269.44  495.22    10
 dplyr_pre_filter  378.79  387.38  459.67  418.58  448.13  765.74    10
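A small base R check of why the simplified version still needs a 0 fallback: a triple like (1, 2, 2) passes the pre-filter c0 <= c1 <= c2, yet by the definition it scores 0. The three triples below are illustrative values only:

```r
# Three illustrative triples that all pass the pre-filter y0 <= y1 <= y2
y0 <- c(1, 1, 1)
y1 <- c(2, 1, 2)
y2 <- c(3, 1, 2)   # third triple: y0 < y1 == y2

# Shortcut that assumes every surviving row is 1, 1/2 or 1/6
shortcut <- ifelse(y0 < y1 & y1 < y2, 1,
                   ifelse(y0 == y1 & y1 < y2, 1/2, 1/6))

# Full definition, keeping the all-equal test and the 0 fallback
exact <- ifelse(y0 < y1 & y1 < y2, 1,
                ifelse(y0 == y1 & y1 < y2, 1/2,
                       ifelse(y0 == y1 & y1 == y2, 1/6, 0)))

shortcut   # 1, 1/6, 1/6
exact      # 1, 1/6, 0
```

With continuous rnorm() data such ties essentially never occur, so the benchmark results are not affected; the fallback only matters for discrete or rounded data.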

Putting it together

The final solution provided at the beginning takes less than a second from start to finish (median about 0.64 s), and the dplyr revision takes less than 2 seconds. Both rely on pre-filtering before applying the vectorized ifelse() / case_when() logic.

Unit: milliseconds
      expr     min      lq    mean  median      uq    max neval
    dt_res  589.83  608.26  736.34  642.46  760.18 1091.1    10
 dt_CJ_res  750.07  764.78  905.12  893.73 1040.21 1140.5    10
 dplyr_res 1156.69 1169.84 1363.82 1337.42 1496.60 1709.8    10
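A different technique, not used in either answer, avoids building triples at all. For a fixed middle value c1[j], the inner double sum equals (#c0 < c1[j]) * (#c2 > c1[j]) + 1/2 * (#c0 == c1[j]) * (#c2 > c1[j]) + 1/6 * (#c0 == c1[j]) * (#c2 == c1[j]), and those counts come from findInterval() against sorted copies of c0 and c2. This is a sketch with hypothetical function names (`triple_sum_count`, `triple_sum_brute`), checked against brute force on small integer data with many ties:

```r
# Counting-based evaluation of V; triple_sum_count is a hypothetical name.
triple_sum_count <- function(y1, y2, y3) {
  s1 <- sort(y1)
  s3 <- sort(y3)
  # findInterval(x, v) counts v <= x; with left.open = TRUE it counts v < x.
  # as.numeric() avoids integer overflow in the products for large inputs.
  lt1 <- as.numeric(findInterval(y2, s1, left.open = TRUE))  # y1 <  y2[j]
  le1 <- as.numeric(findInterval(y2, s1))                    # y1 <= y2[j]
  lt3 <- as.numeric(findInterval(y2, s3, left.open = TRUE))  # y3 <  y2[j]
  le3 <- as.numeric(findInterval(y2, s3))                    # y3 <= y2[j]
  eq1 <- le1 - lt1                # y1 == y2[j]
  eq3 <- le3 - lt3                # y3 == y2[j]
  gt3 <- length(y3) - le3         # y3 >  y2[j]
  total <- sum(lt1 * gt3 + eq1 * gt3 / 2 + eq1 * eq3 / 6)
  total / (as.numeric(length(y1)) * length(y2) * length(y3))
}

# Brute-force reference over all n1 * n2 * n3 triples (small inputs only)
triple_sum_brute <- function(y1, y2, y3) {
  g <- expand.grid(a = y1, b = y2, c = y3)
  mean(with(g, ifelse(a < b & b < c, 1,
                      ifelse(a == b & b < c, 1/2,
                             ifelse(a == b & b == c, 1/6, 0)))))
}

set.seed(1)
# Small integers so that ties occur often
y1 <- sample(1:5, 20, replace = TRUE)
y2 <- sample(1:5, 25, replace = TRUE)
y3 <- sample(1:5, 30, replace = TRUE)
all.equal(triple_sum_count(y1, y2, y3), triple_sum_brute(y1, y2, y3))  # TRUE
```

On the question's data, `triple_sum_count(c0, c1, c2)` should agree with `Result` up to floating-point rounding, while doing two sorts and four findInterval() passes over c1 instead of n1 * n2 * n3 comparisons.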

Data / Code

# https://stackoverflow.com/questions/56185072/fastest-way-to-compute-this-triple-summation-in-r
library(dplyr)
library(tidyr)
library(data.table)

options(digits = 5)
set.seed(123)

nclasses = 3
N_obs = 300

c0 <- rnorm(N_obs)
c1 <- rnorm(N_obs)
c2 <- rnorm(N_obs)

# Base R Data Generation --------------------------------------------------

mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
df <- expand.grid(c0,c1,c2)

identical(mat, unname(as.matrix(df))) # TRUE; unname() drops the column names that as.matrix() keeps

# tidyr and data.table Data Generation ------------------------------------

tib <- crossing(c0, c1, c2) #faster than complete

tib2 <- crossing(c0, c1)%>% #faster but similar in concept to non-equi
  filter(c0 <= c1)%>%
  crossing(c2)%>%
  filter(c1 <= c2)

dt <- data.table(c0)[
  data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = TRUE
][
  data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = TRUE
][c0 <= c1 & c1 <= c2, ]   # also drops any NA rows left by unmatched join keys

# Base R summation --------------------------------------------------------

sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
           ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
                  ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)))) /
  (length(c0) * length(c1) * length(c2))


# dplyr summation ---------------------------------------------------------

tib %>%
  mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                         c0 == c1 & c1 < c2  ~ 1/2,
                         c0 == c1 & c1 == c2 ~ 1/6,
                         TRUE               ~ 0)) %>%
  summarize(mean_res = mean(res))

# data.table summation ----------------------------------------------------

# why base doesn't have case_when, who knows
# (the innermost ifelse() needs its "no" value, otherwise it errors)
dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
                ifelse(c0 == c1 & c1 < c2, 1/2,
                       ifelse(c0 == c1 & c1 == c2, 1/6, 0))))
   ] / (length(c0) * length(c1) * length(c2))


CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2,
               sum(ifelse(c0 < c1 & c1 < c2, 1,
                          ifelse(c0 == c1 & c1 < c2, 1/2,
                                 ifelse(c0 == c1 & c1 == c2, 1/6, 0))))
               ] / (length(c0) * length(c1) * length(c2))

# Benchmarking ------------------------------------------------------------

library(microbenchmark)

# Data generation
microbenchmark('original' = {
  matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
}
, 'as.matrix' = {
  as.matrix(expand.grid(c0,c1,c2)) 
}
, 'expand.grid' = {
  expand.grid(c0,c1,c2) #keep it simpler
}
, 'tidyr_complete' = {
  tibble(c0, c1, c2) %>% complete(c0, c1, c2)
}
, 'tidyr_crossing' = {
  crossing(c0, c1, c2)
}
, 'data.table_CJ' = {
  CJ(c0,c1,c2)
}
, times = 10)

microbenchmark('data.table_non_equi' = {
  data.table(c0
             )[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
               ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
                 ][c0 <= c1 & c1 <= c2, ]
}
, 'data.table_CJ_filter' = {
  CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
}
, 'tidyr_cross_filter' = {
  crossing(c0,c1)%>%filter(c0 <= c1)%>% crossing(c2)%>% filter(c1 <= c2)
}
, times = 10
)

# Summation Calculation
microbenchmark('base' = {
  sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
             ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
                    ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
             ))
  ) / (length(c0)*length(c1)*length(c2))
}
, 'dplyr' = {
  tib %>%
    mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                           c0 == c1 & c1 < c2  ~ 1/2,
                           c0 == c1 & c1 == c2 ~ 1/6,
                           TRUE               ~ 0)) %>%
    summarize(mean_res = mean(res))
}
, 'data.table' = {
  dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
                  ifelse(c0 == c1 & c1 < c2, 1/2,
                         ifelse(c0 == c1 & c1 == c2, 1/6, 0))))
     ] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_pre_filter' = {
  tib2 %>%
    mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                           c0 == c1 & c1 < c2  ~ 1/2,
                           c0 == c1 & c1 == c2 ~ 1/6,
                           TRUE                ~ 0)) %>%
    summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10)

# Start to Finish

microbenchmark('dt_res' = {
  data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
  ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
    ][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                      ifelse(c0 == c1 & c1 < c2, 1/2,
                                             ifelse(c0 == c1 & c1 == c2, 1/6, 0))))
    ] / (length(c0) * length(c1) * length(c2))
}
, 'dt_CJ_res' = {
  CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                                 ifelse(c0 == c1 & c1 < c2, 1/2,
                                                        ifelse(c0 == c1 & c1 == c2, 1/6, 0))))
  ] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_res' = {
  crossing(c0, c1)%>% #faster but similar in concept to non-equi
    filter(c0 <= c1)%>%
    crossing(c2)%>%
    filter(c1 <= c2)%>%
    mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                           c0 == c1 & c1 < c2  ~ 1/2,
                           c0 == c1 & c1 == c2 ~ 1/6,
                           TRUE                ~ 0)) %>%
    summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10
)

Upvotes: 1

Jon Spring

Reputation: 66480

The original took 399 seconds on my computer to execute the Result <- line. This variation using dplyr and tidyr took 7 seconds for the summation part, and I get exactly the same answer. I presume the speedup comes from the dplyr version being vectorized, so it does the same calculation across all 27 million rows at once, whereas the original calls ind() separately for each row through apply().

library(dplyr); library(tidyr)
set.seed(123)  # same draws as the question's c0, c1, c2

combos <- tibble(Y1 = rnorm(300),
                 Y2 = rnorm(300),
                 Y3 = rnorm(300)) %>%
  complete(Y1, Y2, Y3)

combos %>%
  mutate(res = case_when(Y1  < Y2 & Y2 < Y3  ~ 1,
                         Y1 == Y2 & Y2 < Y3  ~ 1/2,
                         Y1 == Y2 & Y2 == Y3 ~ 1/6,
                         TRUE               ~ 0)) %>%
  summarize(mean_res = mean(res))

This also seems solvable algebraically, but I presume the point was to solve it through simulation.

If we have three separate sets of 300 numbers, each drawn with rnorm() in double precision (about 16 significant digits), there is only an infinitesimal chance that any two of them are equal. So we can ignore the 2nd and 3rd cases, which don't occur with the suggested set.seed() and might take billions of runs to encounter once.

Now, how often does Y[1] < Y[2] < Y[3]? For any three distinct numbers there are 6 ways to order them, and since all three come from the same distribution, each of these 6 orders is equally likely. Only 1 of the 6 (16.7%) is ascending, so we should expect to get 1 about 16.7% of the time and 0 the other 83.3%. With set.seed(123), the result is 0 in 22,379,120 of the 27,000,000 cases (82.9%), so the ascending order occurs in the remaining 4,620,880 (17.1%).
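The 1-in-6 argument and the arithmetic above can be checked in a few lines of base R. This sketch uses an arbitrary seed and 100,000 independent triples (not the question's grid), and treats the reported 22,379,120 as the count of zeros, which is what its 82.9% share implies:

```r
set.seed(42)
m  <- 1e5                        # independent triples, not the question's grid
y1 <- rnorm(m); y2 <- rnorm(m); y3 <- rnorm(m)
mean(y1 < y2 & y2 < y3)          # should be close to 1/6 (about 0.167)

# Arithmetic on the reported count for the question's 300^3 grid
total <- 300^3                   # 27,000,000 triples
zeros <- 22379120                # triples scoring 0 (82.9% of the grid)
(total - zeros) / total          # ascending share: 4,620,880 / 27e6, about 0.171
```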

Upvotes: 1
