Reputation: 3
My goal is to compute the following triple summation:
$V = \frac{1}{n_1 n_2 n_3} \sum_{i=1}^{n_1}\sum_{j=1}^{n_2}\sum_{k=1}^{n_3} I(Y_{1i},Y_{2j},Y_{3k})$
where I(Y1, Y2, Y3) is defined as:
I(Y1, Y2, Y3) = 1 if Y1 < Y2 < Y3
I(Y1, Y2, Y3) = 1/2 if Y1 = Y2 < Y3
I(Y1, Y2, Y3) = 1/6 if Y1 = Y2 = Y3
I(Y1, Y2, Y3) = 0 otherwise.
I have implemented the calculation in R with the code below. The issue is that this way the computations are very expensive. I guess that has to do with using expand.grid() to create the matrix of all combinations and then computing the result. Does anyone have a more efficient way to do this?
set.seed(123)
nclasses = 3

# I(Y1, Y2, Y3) from the definition above, applied to one row Y
ind <- function(Y){
  if (Y[1] < Y[2] & Y[2] < Y[3]) {res = 1}
  else if (Y[1] == Y[2] & Y[2] < Y[3]) {res = 1/2}
  else if (Y[1] == Y[2] & Y[2] == Y[3]) {res = 1/6}
  else {res = 0}
  return(res)
}

N_obs = 300
c0 <- rnorm(N_obs)
l0 = length(c0)
c1 <- rnorm(N_obs)
l1 = length(c1)
c2 <- rnorm(N_obs)
l2 = length(c2)

# matrix of all N_obs^3 combinations, one row per (Y1, Y2, Y3) triple
mat <- matrix(unlist(t(matrix(expand.grid(c0, c1, c2)))), ncol = nclasses)
dim(mat)

# apply ind() row by row, then average
Result <- (1/(l0*l1*l2)) * sum(apply(mat, 1, ind))
Upvotes: 0
Views: 327
Reputation: 11255
tl;dr - data.table using non-equi joins can solve it in the same amount of time that tidyr takes just to generate the data. Still, the tidyr / dplyr solution looks better.
# chain of non-equi joins: keep only rows with c0 <= c1 <= c2, then score them
data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = TRUE
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = TRUE
][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                  ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)))
] / (length(c0) * length(c1) * length(c2))
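If non-equi joins are new to you, here is a minimal toy sketch (made-up two-element vectors, not the benchmark data) of what a single on = .(c0 <= c1) step produces:

library(data.table)
x <- data.table(c0 = c(1, 3))
y <- data.table(c1 = c(2, 4))
# keep every (c0, c1) pair with c0 <= c1; x.c0 / i.c1 recover the original values
x[y, on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = TRUE]
#    c0 c1
# 1:  1  2
# 2:  1  4
# 3:  3  4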
There are two speedups - how the data is generated and then the calculation itself.
The fastest way is to keep it simple. Instead of transposing and unlisting, you can use as.matrix for clarity and a slight speed bump. Or you can keep the expand.grid result as a data.frame, which would be similar to the tidyr solution that creates a tibble.
The data.table equivalent is CJ(c0, c1, c2), which is around 10 times faster than the fastest base or tidyr equivalent.
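As a small sketch of the equivalence (toy vectors; note CJ() sorts its result, so the row order differs from expand.grid()):

library(data.table)
a <- 1:2; b <- 3:4
expand.grid(a, b)  # 4 rows, first column varies fastest: (1,3) (2,3) (1,4) (2,4)
CJ(a, b)           # same 4 combinations, sorted: (1,3) (1,4) (2,3) (2,4)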
#Creating dataset
Unit: milliseconds
expr min lq mean median uq max neval
original 1185.10 1239.37 1478.46 1503.68 1690.47 1899.37 10
as.matrix 1023.49 1041.72 1213.17 1198.24 1360.51 1420.78 10
expand.grid 764.43 840.11 1030.13 1030.79 1146.82 1354.06 10
tidyr_complete 2811.00 2948.86 3118.33 3158.59 3290.21 3364.52 10
tidyr_crossing 1154.94 1171.01 1311.71 1233.40 1545.30 1609.86 10
data.table_CJ 154.71 155.30 175.65 162.54 174.96 291.14 10
Another approach is to use non-equi joins or to pre-filter the data. We know that if c0 > c1 or c1 > c2, the summand is 0. That way we can filter out combinations that we know we don't need to store in memory, which creates the combos faster.
While both of these approaches are slower than data.table::CJ(), they set the stage better for the triple summation.
# 'data.table_CJ_filter' = CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2, ]
# 'tidyr_cross_filter' = crossing(c0, c1) %>% filter(c0 <= c1) %>% crossing(c2) %>% filter(c1 <= c2)
#Creating dataset with future calcs in mind
Unit: milliseconds
expr min lq mean median uq max neval
data.table_non_equi 358.41 360.35 373.95 374.57 383.62 400.42 10
data.table_CJ_filter 515.50 517.99 605.06 527.63 661.54 856.43 10
tidyr_cross_filter 776.91 783.35 980.19 928.25 1178.47 1287.91 10
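To get a sense of how much the pre-filter shrinks the problem, here is a quick sketch using the same c0, c1, c2 (for continuous i.i.d. draws, roughly 1/6 of all triples satisfy c0 <= c1 <= c2):

library(data.table)
filtered <- CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2, ]
# about 1/6 of the 27 million rows survive the filter
nrow(filtered) / (length(c0) * length(c1) * length(c2))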
@Jon Spring's solution is great. case_when and ifelse are vectorized, whereas your original if ... else statements were not. I translated Jon's answer to Base R. It's faster than your original solution but still takes about 50% longer than dplyr.
One note: if you did the non-equi join, you can further simplify the case_when because the filtering has already been done - every remaining row gets 1, 1/2, or 1/6. (Strictly speaking, a row with c0 < c1 and c1 == c2 now gets 1/6 where the original I() gives 0, but with continuous rnorm draws ties essentially never occur.) The pre-filtered solutions are roughly 10x - 30x faster than the same calculations on data that had not been pre-filtered.
Unit: milliseconds
expr min lq mean median uq max neval
base 5666.93 6003.87 6303.27 6214.58 6416.42 7423.30 10
dplyr 3633.48 3963.47 4160.68 4178.15 4395.96 4530.15 10
data.table 236.83 262.10 305.19 268.47 269.44 495.22 10
dplyr_pre_filter 378.79 387.38 459.67 418.58 448.13 765.74 10
The final solution provided at the beginning takes less than a second; the dplyr revision takes less than 2 seconds. Both solutions rely on pre-filtering before going to the logical if ... else statement.
Unit: milliseconds
expr min lq mean median uq max neval
dt_res 589.83 608.26 736.34 642.46 760.18 1091.1 10
dt_CJ_res 750.07 764.78 905.12 893.73 1040.21 1140.5 10
dplyr_res 1156.69 1169.84 1363.82 1337.42 1496.60 1709.8 10
Data / Code
# https://stackoverflow.com/questions/56185072/fastest-way-to-compute-this-triple-summation-in-r
library(dplyr)
library(tidyr)
library(data.table)
options(digits = 5)
set.seed(123)
nclasses = 3
N_obs = 300
c0 <- rnorm(N_obs)
c1 <- rnorm(N_obs)
c2 <- rnorm(N_obs)
# Base R Data Generation --------------------------------------------------
mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
df <- expand.grid(c0,c1,c2)
identical(mat, unname(as.matrix(df))) #TRUE - names are different with as.matrix
# tidyr and data.table Data Generation ------------------------------------
tib <- crossing(c0, c1, c2) #faster than complete
tib2 <- crossing(c0, c1) %>%  # faster but similar in concept to the non-equi join
  filter(c0 <= c1) %>%
  crossing(c2) %>%
  filter(c1 <= c2)
dt <- data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = TRUE
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = TRUE
][c0 <= c1 & c1 <= c2, ]
# Base R summation --------------------------------------------------------
sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
))
) / (length(c0)*length(c1)*length(c2))
# dplyr summation ---------------------------------------------------------
tib %>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
c0 == c1 & c1 == c2 ~ 1/6,
TRUE ~ 0)) %>%
summarize(mean_res = mean(res))
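One subtlety worth flagging: mean(res) works here only because tib still contains all 27 million rows. On pre-filtered data the denominator must remain the full length(c0) * length(c1) * length(c2), which is why the pre-filtered variants in the benchmarks use sum(res) divided by that count.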
# data.table summation ----------------------------------------------------
# why base doesn't have case_when, who knows
dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
                ifelse(c0 == c1 & c1 < c2, 1/2,
                       ifelse(c0 == c1 & c1 == c2, 1/6, 0))))
] / (length(c0) * length(c1) * length(c2))
CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6
)))
] / (length(c0) * length(c1) * length(c2))
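As a quick sanity check tying these together (a sketch: since ties have probability zero with rnorm draws, V is essentially the proportion of strictly ascending triples):

dt[, sum(c0 < c1 & c1 < c2)] / (length(c0) * length(c1) * length(c2))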
# Benchmarking ------------------------------------------------------------
library(microbenchmark)
# Data generation
microbenchmark('original' = {
matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
}
, 'as.matrix' = {
as.matrix(expand.grid(c0,c1,c2))
}
, 'expand.grid' = {
expand.grid(c0,c1,c2) #keep it simpler
}
, 'tidyr_complete' = {
tibble(c0, c1, c2) %>% complete(c0, c1, c2)
}
, 'tidyr_crossing' = {
crossing(c0, c1, c2)
}
, 'data.table_CJ' = {
CJ(c0,c1,c2)
}
, times = 10)
microbenchmark('data.table_non_equi' = {
data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = TRUE
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = TRUE
][c0 <= c1 & c1 <= c2, ]
}
, 'data.table_CJ_filter' = {
CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
}
, 'tidyr_cross_filter' = {
crossing(c0, c1) %>% filter(c0 <= c1) %>% crossing(c2) %>% filter(c1 <= c2)
}
, times = 10
)
# Summation Calculation
microbenchmark('base' = {
sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
))
) / (length(c0)*length(c1)*length(c2))
}
, 'dplyr' = {
tib %>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
c0 == c1 & c1 == c2 ~ 1/6,
TRUE ~ 0)) %>%
summarize(mean_res = mean(res))
}
, 'data.table' = {
dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
))
] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_pre_filter' = {
tib2 %>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
TRUE ~ 1/6)) %>%
summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10)
# Start to Finish
microbenchmark('dt_res' = {
data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = TRUE
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = TRUE
][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
))
] / (length(c0) * length(c1) * length(c2))
}
, 'dt_CJ_res' = {
CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
))
] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_res' = {
crossing(c0, c1) %>%  # faster but similar in concept to the non-equi join
  filter(c0 <= c1) %>%
  crossing(c2) %>%
  filter(c1 <= c2) %>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
TRUE ~ 1/6)) %>%
summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10
)
Upvotes: 1
Reputation: 66480
The original took 399 seconds on my computer to execute the Result <- line. This variation using dplyr & tidyr took 7 seconds to do the summation part, and I get exactly the same answer. I presume the speedup comes from how the dplyr version is vectorized and can do the same calc across all 27 million rows, whereas the original is, I suspect, re-calculating something each time.
library(dplyr); library(tidyr)
combos <- tibble(Y1 = rnorm(300),
Y2 = rnorm(300),
Y3 = rnorm(300)) %>%
complete(Y1, Y2, Y3)
combos %>%
mutate(res = case_when(Y1 < Y2 & Y2 < Y3 ~ 1,
Y1 == Y2 & Y2 < Y3 ~ 1/2,
Y1 == Y2 & Y2 == Y3 ~ 1/6,
TRUE ~ 0)) %>%
summarize(mean_res = mean(res))
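To make the vectorization point concrete, here is a small side-by-side sketch using the mat and ind() objects from the question (subset to a few rows so the slow version finishes instantly):

# row-wise: one R-level call to ind() per row -- 27 million calls at full size
apply(mat[1:5, ], 1, ind)
# vectorized: whole-column comparisons handled in compiled code
ifelse(mat[1:5, 1] < mat[1:5, 2] & mat[1:5, 2] < mat[1:5, 3], 1, 0)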
This also seems solvable algebraically, but I presume the point of this was to solve it through simulation.
If we have three separate sets of 300 numbers, each drawn to 16 digits using rnorm, there is an infinitesimal chance that any two would match. So we can ignore the 2nd and 3rd cases, which don't occur with the suggested set.seed and might take billions of runs to encounter once.
Now, how often does Y1 < Y2 < Y3? For any set of three different numbers, there are 6 ways to sort them, and since each of these numbers has the same distribution, any of those 6 orders is equally likely. Only 1 of the 6 (16.7%) is in ascending order, so we should expect to get 1 about 16.7% of the time, and 0 the other 83.3%. With set.seed(123), the non-ascending scenario arises in 22,379,120 of the 27,000,000 cases (82.9%), leaving 4,620,880 ascending cases (17.1%).
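That argument can be checked in one line on the combos tibble from above (with ties ignorable, this proportion equals the simulated mean_res):

# proportion of strictly ascending triples; should land near 1/6
combos %>% summarize(prop_ascending = mean(Y1 < Y2 & Y2 < Y3))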
Upvotes: 1