Saurabh
Saurabh

Reputation: 1626

Reference a previous value at runtime in data.table

I have a data.table as follows

library(data.table)
data = structure(list(value = c(54.71, 62.48, 57.88, 60.64, 56, 54.28, 
55.22, 63.77, 64.47, 67.36, 64.45, 64.7, 65.15, 62.19, 70.25, 
75.47, 79.75, 75.75, 75.24, 76, 80.25, 91.04, 95.13, 102.18, 
92.28, 87.24, 82.32, 89.73, 77.01, 73.06, 74.51, 68.52, 65.64, 
66.65, 60.36, 57.58, 54.92, 51.16, 46.87, 53.24, 52.99, 59.24, 
58, 62.93, 60.05, 60.79, 115.09, 125.28, 164.87, 118.48, 112.28, 
123.73, 142.95, 134.49, 129.28, 128.86, 144.28, 140.52, 144.3, 
126.07, 123.33, 127.29, 112.46, 110.46, 104.51, 110.4, 104.65, 
97.55, 91.79, 100.61, 101.44, 107.38, 111.24, 116.33, 113.75, 
108.56, 109.02, 114.2, 107.36, 98.19), upper = c(76.31, 82.9, 
88.17, 88.44, 83.69, 82.03, 77.87, 84.06, 85.11, 92.01, 88.91, 
88.86, 88.65, 86.81, 89.81, 99.53, 102.75, 100.9, 99.56, 97.89, 
100.81, 108.36, 116.05, 117.58, 116.51, 106.83, 100.9, 108.91, 
105.27, 93.69, 103.98, 100.61, 95.15, 96.8, 90.28, 86.06, 81.53, 
77.23, 76.3, 79.68, 81.95, 82.72, 81.05, 84.32, 84.93, 82.54, 
127.88, 150.07, 186.47, 192.06, 176.37, 174.29, 190.06, 204.7, 
188.58, 188.46, 195.84, 202.25, 194.74, 185.08, 175.34, 176.93, 
170.44, 157.73, 157.63, 157.99, 151.59, 141.08, 131.59, 130.42, 
138.15, 140.31, 143.42, 150.07, 147.75, 142.11, 140.09, 144.33, 
141.45, 133.73), lower = c(30.22, 36.13, 39.01, 38.48, 34.51, 
32.6, 31.6, 36.28, 38.3, 44.67, 43.08, 39.94, 42.28, 40.71, 42.94, 
51.51, 55.67, 54.38, 54.23, 54.55, 57.46, 61.74, 71.63, 77.6, 
80.54, 69.57, 62.12, 67.59, 59.28, 42.7, 51.32, 45.09, 40.08, 
42.53, 35.77, 32.55, 27.1, 20.57, 21.06, 24.82, 28.72, 30.88, 
30.95, 35.52, 36.42, 34.22, 70.75, 85.33, 103.24, 87.58, 67.9, 
62.02, 73.05, 84.76, 71.63, 77.33, 84.64, 90.02, 89.1, 77.7, 
72.82, 75.07, 68.71, 59.23, 62.23, 64.66, 62.64, 60.89, 55.12, 
61.59, 71.59, 74.1, 79.39, 85.32, 83.95, 76.89, 76.09, 79.39, 
74.65, 73.61), xyz_expected = c(TRUE, TRUE, TRUE, TRUE, TRUE, 
TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, 
TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, 
TRUE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, 
FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, 
FALSE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, 
TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, 
TRUE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, 
FALSE, FALSE, FALSE, FALSE, FALSE)), row.names = c(NA, -80L), class = c("data.table", 
"data.frame"))

> data

      value  upper  lower xyz_expected
 1:  54.71  76.31  30.22         TRUE
 2:  62.48  82.90  36.13         TRUE
 3:  57.88  88.17  39.01         TRUE
 4:  60.64  88.44  38.48         TRUE
 5:  56.00  83.69  34.51         TRUE
 6:  54.28  82.03  32.60         TRUE
 7:  55.22  77.87  31.60         TRUE
 8:  63.77  84.06  36.28         TRUE
 9:  64.47  85.11  38.30         TRUE
10:  67.36  92.01  44.67         TRUE
11:  64.45  88.91  43.08         TRUE
12:  64.70  88.86  39.94         TRUE
13:  65.15  88.65  42.28         TRUE
14:  62.19  86.81  40.71         TRUE
15:  70.25  89.81  42.94         TRUE
16:  75.47  99.53  51.51         TRUE
17:  79.75 102.75  55.67         TRUE
18:  75.75 100.90  54.38         TRUE
19:  75.24  99.56  54.23         TRUE
20:  76.00  97.89  54.55         TRUE
21:  80.25 100.81  57.46         TRUE
22:  91.04 108.36  61.74         TRUE
23:  95.13 116.05  71.63         TRUE
24: 102.18 117.58  77.60         TRUE
25:  92.28 116.51  80.54         TRUE
26:  87.24 106.83  69.57         TRUE
27:  82.32 100.90  62.12         TRUE
28:  89.73 108.91  67.59         TRUE
29:  77.01 105.27  59.28        FALSE
30:  73.06  93.69  42.70        FALSE
31:  74.51 103.98  51.32        FALSE
32:  68.52 100.61  45.09        FALSE
33:  65.64  95.15  40.08        FALSE
34:  66.65  96.80  42.53        FALSE
35:  60.36  90.28  35.77        FALSE
36:  57.58  86.06  32.55        FALSE
37:  54.92  81.53  27.10        FALSE
38:  51.16  77.23  20.57        FALSE
39:  46.87  76.30  21.06        FALSE
40:  53.24  79.68  24.82        FALSE
41:  52.99  81.95  28.72        FALSE
42:  59.24  82.72  30.88        FALSE
43:  58.00  81.05  30.95        FALSE
44:  62.93  84.32  35.52        FALSE
45:  60.05  84.93  36.42        FALSE
46:  60.79  82.54  34.22        FALSE
47: 115.09 127.88  70.75         TRUE
48: 125.28 150.07  85.33         TRUE
49: 164.87 186.47 103.24         TRUE
50: 118.48 192.06  87.58         TRUE
51: 112.28 176.37  67.90         TRUE
52: 123.73 174.29  62.02         TRUE
53: 142.95 190.06  73.05         TRUE
54: 134.49 204.70  84.76         TRUE
55: 129.28 188.58  71.63         TRUE
56: 128.86 188.46  77.33         TRUE
57: 144.28 195.84  84.64         TRUE
58: 140.52 202.25  90.02         TRUE
59: 144.30 194.74  89.10         TRUE
60: 126.07 185.08  77.70         TRUE
61: 123.33 175.34  72.82         TRUE
62: 127.29 176.93  75.07         TRUE
63: 112.46 170.44  68.71         TRUE
64: 110.46 157.73  59.23         TRUE
65: 104.51 157.63  62.23         TRUE
66: 110.40 157.99  64.66         TRUE
67: 104.65 151.59  62.64         TRUE
68:  97.55 141.08  60.89        FALSE
69:  91.79 131.59  55.12        FALSE
70: 100.61 130.42  61.59        FALSE
71: 101.44 138.15  71.59        FALSE
72: 107.38 140.31  74.10        FALSE
73: 111.24 143.42  79.39        FALSE
74: 116.33 150.07  85.32        FALSE
75: 113.75 147.75  83.95        FALSE
76: 108.56 142.11  76.89        FALSE
77: 109.02 140.09  76.09        FALSE
78: 114.20 144.33  79.39        FALSE
79: 107.36 141.45  74.65        FALSE
80:  98.19 133.73  73.61        FALSE
     value  upper  lower xyz_expected

I want to calculate the new value of xyz based on the fcase statement given in the code below. The same can be done with a lengthy for loop, but I want to use data.table only.

# Seed every row with TRUE so that `xyz` exists before the second call
# refers to it.
data[, xyz := TRUE]
data[, xyz := {
  # value, upper, lower and xyz are whole columns here, so every fcase()
  # condition is evaluated for all rows at once rather than row by row.
  fcase(
    value > upper, TRUE,
    value < lower, FALSE,
    data.table::between(value, lower = lower, upper = upper), 
    {
      # The result of this shift() is neither assigned nor the last
      # expression of the block, so it is discarded.
      shift(xyz, 1, type = "lag")
      # `if` expects a single TRUE/FALSE, but this condition is a vector
      # with one element per row.
      if(xyz == TRUE & lower < shift(lower, 1, type = "lag"))
      {
        # Assigns a local variable named `lower`; the `lower` column of
        # `data` is not modified.
        lower = shift(lower, 1, type = "lag")
      } 
      if(xyz == FALSE & upper > shift(upper, 1, type = "lag"))
      {
        # Same as above: only a local `upper` is assigned.
        upper = shift(upper, 1, type = "lag")
      } 
      # The value handed to fcase() for the in-between rows is the value of
      # the last `if` above, not the lagged `xyz`.
    }
  )
}]

Running the above code gives me the following error -

Error in `:=`(xyz, { : 
  Check that is.data.table(DT) == TRUE. Otherwise, := and `:=`(...) are defined for use in j, once only and in particular ways. See help(":=").

I would appreciate it if someone could show me how to resolve this error. Or is there no way to accomplish this task using data.table?

Update 1 - The following code runs, but the last fcase condition creates a local copy of the variable lower, so the value of lower in the original data.table data is not updated at runtime.

# Seed every row with TRUE so that `xyz` exists before the second call
# refers to it.
data[, xyz := TRUE]
data[, xyz := {
  # All names below are whole columns, so the conditions are evaluated for
  # every row at once; `xyz` is the seeded column, not a running result.
  fcase(
    value > upper, TRUE,
    value < lower, FALSE,
    data.table::between(value, lower = lower, upper = upper), 
    {
      # These two assignments create local variables only; the `lower` and
      # `upper` columns of `data` keep their original values.
      lower = ifelse(xyz == TRUE & lower < shift(lower, 1, type = "lag"), shift(lower, 1, type = "lag"), lower)
      upper = ifelse(xyz == FALSE & upper > shift(upper, 1, type = "lag"), shift(upper, 1, type = "lag"), upper)
      # Value used for the in-between rows: the previous row's seeded `xyz`
      # (NA for the first row, which has no predecessor).
      shift(xyz, 1, type = "lag")
    }
  )
}]

Update 2 - As suggested by @r2avens, I have tried using the function Reduce, but I have no idea how to use it in this scenario. I would appreciate it if someone could show me the correct way to use Reduce.

  # NOTE(review): `temp_dt` is not defined in this listing; presumably it
  # holds the same columns as `data` - confirm.
  temp_dt[, xyz := TRUE]
  temp_dt[, xyz := {
    # Reduce() expects a function `f` first and the vector to fold as `x`.
    # Here the fcase() call sits in the position of `f` and no `x` is
    # supplied, which is what the 'argument "x" is missing' error reports.
    Reduce(fcase(
      value > upper, TRUE,
      value < lower, FALSE,
      data.table::between(value, lower = lower, upper = upper), 
      {
        # Local assignments only; the columns of `temp_dt` are not changed.
        lower = ifelse(xyz == TRUE & lower < shift(lower, 1, type = "lag"), shift(lower, 1, type = "lag"), lower)
        upper = ifelse(xyz == FALSE & upper > shift(upper, 1, type = "lag"), shift(upper, 1, type = "lag"), upper)
        shift(xyz, 1, type = "lag")
      }
    ), accumulate = TRUE)
  }]
Error in Reduce(fcase(value > upper, TRUE, value < lower, FALSE, data.table::between(value, :
argument "x" is missing, with no default

Update 3 - The following is the working for loop:

#' Propagate the `xyz` flag down the rows of a table
#'
#' Walks the rows in order. For each row after the first, the row's `value`
#' is compared with the previous row's (possibly already adjusted) bounds:
#' above `upper` sets `xyz` to `TRUE`, below `lower` sets it to `FALSE`,
#' otherwise the previous row's `xyz` is carried forward. When the flag is
#' carried forward, `lower` is not allowed to fall below the previous `lower`
#' while `xyz` is `TRUE`, and `upper` is not allowed to rise above the
#' previous `upper` while `xyz` is `FALSE`.
#'
#' @param x A data.frame or data.table with columns `value`, `upper`,
#'   `lower` and `xyz`. The first row's `xyz` is the starting state.
#' @return `x` with the `xyz`, `lower` and `upper` columns updated.
calculate <- function(x) {
  n_rows <- nrow(x)
  # With fewer than two rows there is nothing to propagate; `2:nrow(x)`
  # would count down (2, 1) and index past the end of the table.
  if (n_rows < 2L) {
    return(x)
  }

  # Work on plain vectors and write them back once, instead of re-assigning
  # into the table on every iteration.
  value <- as.numeric(x$value)
  lower <- x$lower
  upper <- x$upper
  xyz <- x$xyz

  for (current in seq_len(n_rows)[-1L]) {
    previous <- current - 1L

    # isTRUE() makes a comparison involving NA fall through to the next
    # branch instead of raising an error in `if`.
    if (isTRUE(value[current] > as.numeric(upper[previous]))) {
      xyz[current] <- TRUE
    } else if (isTRUE(value[current] < as.numeric(lower[previous]))) {
      xyz[current] <- FALSE
    } else {
      xyz[current] <- xyz[previous]

      if (isTRUE(xyz[current] == TRUE) &&
          isTRUE(as.numeric(lower[current]) < as.numeric(lower[previous]))) {
        lower[current] <- lower[previous]
      }

      if (isTRUE(xyz[current] == FALSE) &&
          isTRUE(as.numeric(upper[current]) > as.numeric(upper[previous]))) {
        upper[current] <- upper[previous]
      }
    }
  }

  x$xyz <- xyz
  x$lower <- lower
  x$upper <- upper
  x
}


dt1 = calculate(data)

Benchmarking - Note that the statements in the data.table solution do not exactly match the statements in the for-loop solution, since an exact solution using data.table is yet to be found. It does, however, give a very close approximation.

# Time each candidate 10 times on the same `data`.
# NOTE(review): `MrSmith` and `Alexis_cpp` are not defined in this listing;
# presumably they are the answers' functions under other names - confirm.
microbenchmark::microbenchmark(
forloop = calculate(data),
MrSmith = MrSmith(data, nrow(data)),
Alexis_cpp = Alexis_cpp(data),
# The vectorised data.table attempt; `:=` updates `data` by reference, so
# this expression changes the table the other candidates read.
datatable = {data[, xyz := TRUE]
data[, xyz := {
  fcase(
    value > upper, TRUE,
    value < lower, FALSE,
    data.table::between(value, lower = lower, upper = upper), 
    {
      # Local assignments only; the columns of `data` are not changed.
      lower = ifelse(xyz == TRUE & lower < shift(lower, 1, type = "lag"), shift(lower, 1, type = "lag"), lower)
      upper = ifelse(xyz == FALSE & upper > shift(upper, 1, type = "lag"), shift(upper, 1, type = "lag"), upper)
      shift(xyz, 1, type = "lag")
    }
  )
}]}
, times = 10)

Unit: microseconds
       expr        min         lq        mean      median         uq        max neval cld
    forloop  10718.595  10774.386  10953.7406  10846.0465  11120.934  11325.962    10  b 
    MrSmith 278600.405 280761.948 296481.4276 298688.6040 305259.497 331570.519    10   c
 Alexis_cpp     13.288     13.924     25.4258     27.9855     32.196     45.881    10 a  
  datatable   1656.803   1700.195   1826.1999   1764.6285   1938.074   2125.853    10 ab

Upvotes: 1

Views: 240

Answers (2)

MrSmithGoesToWashington
MrSmithGoesToWashington

Reputation: 1076

Maybe a recursive function could help, but it is not very efficient.

#' Recursively compute `xyz` for the first `n` rows of a data.table
#'
#' Row 1 is flagged `TRUE`. Every later row is resolved after the rows before
#' it: its `value` is compared with the previous row's bounds, and when it
#' lies between them the previous `xyz` is carried forward and the row's
#' `lower` (when `xyz` is `TRUE`) or `upper` (when `xyz` is `FALSE`) is
#' clamped to the previous row's bound.
#'
#' `:=` updates `tbl` by reference, so the table passed by the caller is
#' modified as well as returned. One recursion level is used per row, so very
#' long tables can exceed R's nesting limit.
#'
#' @param tbl A data.table with columns `value`, `upper` and `lower`.
#' @param n Number of leading rows to process, typically `nrow(tbl)`.
#' @return `tbl`, updated in place.
myFun <- function(tbl, n) {
  # Without this guard, n = 0 (an empty table) never reaches the n == 1 base
  # case and recurses until R aborts.
  if (n < 1L) {
    return(tbl)
  }

  if (n == 1L) {
    tbl[n, xyz := TRUE]
    return(tbl)
  }

  # Resolve all earlier rows first so that row n - 1 holds its final values.
  tbl <- myFun(tbl, n - 1L)
  prev_upper <- tbl[n - 1, upper]
  prev_lower <- tbl[n - 1, lower]
  prev_xyz <- tbl[n - 1, xyz]

  tbl[n, xyz := fcase(value > prev_upper, TRUE,
                      value < prev_lower, FALSE,
                      value <= prev_upper & value >= prev_lower, prev_xyz)]
  # `xyz` below is the value just written for row n. The trailing
  # all-TRUE condition acts as the "otherwise keep the current value" case.
  tbl[n, lower := fcase(value <= prev_upper & value >= prev_lower
                        & lower < prev_lower
                        & xyz == TRUE, prev_lower,
                        rep_len(TRUE, length(lower)), lower)]
  tbl[n, upper := fcase(value <= prev_upper & value >= prev_lower
                        & xyz == FALSE
                        & upper > prev_upper, prev_upper,
                        rep_len(TRUE, length(upper)), upper)]

  tbl
}

res <- myFun(data, nrow(data) )

Upvotes: 1

Alexis
Alexis

Reputation: 5059

I don't see how you could do this with data.table even if you use Reduce because you want to modify more than 1 value iteratively. In case it may convince you to try something else, here's a C++ version with Rcpp.

Save this in foo.cpp:

// [[Rcpp::plugins(cpp11)]]
#include <Rcpp.h>
using namespace Rcpp;

// Propagates the xyz flag down the rows of df.
//
// For every row after the first, the row's value is compared with the
// previous row's bounds: above upper sets xyz to true, below lower sets it to
// false, otherwise the previous xyz is carried forward and the row's lower
// (xyz true) or upper (xyz false) is clamped to the previous row's bound.
//
// The column vectors are written through directly and nothing is returned.
// NOTE(review): this relies on value/lower/upper being stored as doubles and
// xyz as logical so that the vectors alias the data frame's columns rather
// than converted copies - confirm for the caller's data.
// [[Rcpp::export]]
SEXP foo_cpp(DataFrame df) {
    NumericVector value = df["value"];
    NumericVector lower = df["lower"];
    NumericVector upper = df["upper"];
    LogicalVector xyz = df["xyz"];

    const int n_rows = df.nrow();

    // Indices are 0-based here, so the second row (the first one that has a
    // predecessor) is index 1, matching `2:nrow(x)` in the R loop.
    for (int current = 1; current < n_rows; ++current) {
        const int previous = current - 1;

        if (value[current] > upper[previous]) {
            xyz[current] = true;
        } else if (value[current] < lower[previous]) {
            xyz[current] = false;
        } else {
            const int state = xyz[previous];
            xyz[current] = state;

            // NA_LOGICAL is a non-zero integer, so test it explicitly: an NA
            // state must trigger neither clamp, as in the R loop.
            const bool is_true = state != NA_LOGICAL && state != 0;
            const bool is_false = state == 0;

            if (is_true && lower[current] < lower[previous]) {
                lower[current] = lower[previous];
            }

            if (is_false && upper[current] > upper[previous]) {
                upper[current] = upper[previous];
            }
        }
    }

    return R_NilValue;
}

and in R:

library(Rcpp)
sourceCpp("foo.cpp")
foo_cpp(data)

Upvotes: 3

Related Questions