Reputation: 11
I'm trying to write a recursive SQL query (DuckDB) to calculate a cumulative sum with a reset condition. In the example below, the running sum restarts on the row after it reaches 30 or more (c_val >= 30).
I have done this successfully (see Code 1 below).
However, I would also like to solve this problem in a more efficient way (in DuckDB), as the code takes a while to run on thousands of records. Perhaps Code 1 is already (near) optimal and I'm wasting my time?
Code 2 is my attempt at what I believe to be a faster approach; however, it doesn't work (I think it runs into an infinite loop).
The reason I think it is a faster approach (if it worked) is because it attempts to:
In the worst case, I would expect this to have the same run time as Code 1 (if it had to iterate once per row, as Code 1 does). In all other cases the run time should be lower, proportional to the number of iterations needed, i.e. the number of times the threshold condition is satisfied.
It isn't obvious to me what is causing the issue.
My attempts to debug focused on isolating various parts of Code 2. Each part seemed to produce what one would expect in isolation but not work as intended on the whole.
Any help is appreciated.
library(data.table)
library(duckdb)
ref <- data.table(
idx = c(1:20)
,val = c(1,2,5,5,5,15,15,15,31,3,25,10,29,0,0,9,4,8,8,10)
)
con <- dbConnect(duckdb())
dbWriteTable(con, "ref", ref, overwrite=T)
### Code 1
code_1 <- dbGetQuery(con, paste0("
WITH
RECURSIVE cte as (
SELECT
idx
,val
,val AS c_val
FROM
ref
WHERE
idx = 1
UNION ALL
SELECT
ref.idx
,ref.val
,(CASE WHEN ref.val + cte.c_val >= 30 + ref.val THEN ref.val
ELSE ref.val + cte.c_val
END) AS c_val
FROM
cte
JOIN
ref ON ref.idx = cte.idx + 1
)
SELECT * FROM cte
"))
### Code 2
code_2 <- dbGetQuery(con, paste0("
WITH
RECURSIVE cte as (
SELECT
0 AS idx
,0 AS val
,0 AS c_val
UNION ALL
SELECT
idx
,val
,c_val
FROM
(
SELECT
*
FROM
(
SELECT
*
,SUM(val) OVER (ORDER BY idx) AS c_val
,c_val-val AS c_val_adj
FROM
(
SELECT
idx
,val
FROM
ref
WHERE
idx NOT IN (SELECT idx FROM cte)
)
)
WHERE
c_val_adj <= 30
)
)
SELECT * FROM cte
"))
For reference, here is what Code 1 generates (which is what I want Code 2 to generate as well): [image: output of Code 1]
Upvotes: 1
Views: 167
Reputation: 270045
1) Below we get a 3x speedup on the data shown in the question by registering the SQL table in R, performing the accumulation in R and then registering it back to the database. The input is shown in the Note at the end. Run that first.
library(data.table)
library(duckdb)
library(microbenchmark)
con <- dbConnect(duckdb())
duckdb_register(con, "ref", ref, overwrite = TRUE)
microbenchmark(times = 10,
SQL = code_1 <- dbGetQuery(con, "
WITH RECURSIVE cte as (
SELECT idx, val, val AS c_val FROM ref WHERE idx = 1
UNION ALL
SELECT
ref.idx,
ref.val,
(CASE WHEN ref.val + cte.c_val >= 30 + ref.val
THEN ref.val
ELSE ref.val + cte.c_val END) AS c_val
FROM cte
JOIN ref ON ref.idx = cte.idx + 1
)
SELECT * FROM cte
"),
R = {
refDF <- dbGetQuery(con, "SELECT * FROM ref")
f3 <- function(x, y) if (x >= 30) y else x + y
refDF$cum <- Reduce(f3, refDF$val, init = 0, acc = TRUE)[-1]
duckdb_register(con, "refDF", refDF, overwrite = TRUE)
})
giving
## Unit: milliseconds
## expr min lq mean median uq max neval cld
## SQL 14.9846 15.2263 25.40338 16.52835 22.7394 91.8009 10 a
## R 5.7175 5.7452 8.60860 5.82635 5.9205 33.7428 10 b
identical(refDF$cum, expected)
## [1] TRUE
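For readers who do not use R, the fold performed by Reduce above can be sketched in Python. This is only an illustration of the same rule, not part of the benchmark; the function name reset_cumsum and the threshold parameter are made up for this sketch, and the data is the val column from the question.

```python
from itertools import accumulate

def reset_cumsum(values, threshold=30):
    """Running total that restarts on the row after it reaches the threshold."""
    def step(running, value):
        # Same rule as f3 in the R code: if the previous total has already
        # reached the threshold, start over from the current value.
        return value if running >= threshold else running + value
    # initial=0 plays the role of init = 0 in Reduce;
    # [1:] drops that seed, like [-1] in the R code.
    return list(accumulate(values, step, initial=0))[1:]

# val column from the question
val = [1, 2, 5, 5, 5, 15, 15, 15, 31, 3, 25, 10, 29, 0, 0, 9, 4, 8, 8, 10]
print(reset_cumsum(val))
```

On this input the result matches the expected vector given in the Note at the end.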
2) This alternative uses the duckdb cli and gawk rather than R. Ensure you have both installed. RTools has gawk for Windows and on Linux normally gawk is included.
Run the code in the Note at the end in R and then run this in R creating the ref.duckdb database with the input table, ref. This is just to set up the test framework and is not part of the actual processing.
library(duckdb)
con <- dbConnect(duckdb(), "ref.duckdb")
dbWriteTable(con, "ref", ref)
dbDisconnect(con)
Now define this gawk file called ref.awk using a text editor:
BEGIN { print "idx,val,cum" }
NR == 1 { next }
{ if (prev >= 30) prev = $2; else prev = $2 + prev; print $0 "," prev }
From the system command line (not from R) run the following. (If you are on Windows and gawk is not on your PATH you may need to refer to it as something like \Rtools44\usr\bin\gawk.) The first line copies the ref table in ref.duckdb to a csv file; the next line does the actual processing using gawk, creating ref2.csv. Note that the -k option tells gawk that the input is a csv file; it is available in gawk but maybe not in other versions of awk. The last line uploads ref2.csv back to the ref.duckdb database.
duckdb ref.duckdb "copy ref to 'ref.csv'"
gawk -k -f ref.awk ref.csv > ref2.csv
duckdb ref.duckdb "create table ref2 as from read_csv_auto('ref2.csv')"
Note
library(data.table)
ref <- data.table(
idx = 1:20,
val = c(1,2,5,5,5,15,15,15,31,3,25,10,29,0,0,9,4,8,8,10)
)
expected <- c(1, 3, 8, 13, 18, 33, 15, 30, 31, 3, 28, 38, 29, 29, 29, 38,
4, 12, 20, 30)
Upvotes: 2
Reputation: 5916
Here is another approach.
This example is somewhat closer to your Code 2. The more rows fall into each group, the more this query pays off.
The recursion depth is (row count)/(average group size).
In your Code 1 the depth equals the row count.
Each step of the recursive query takes one range of rows: the rows whose cumulative sum within the range stays below 30, plus the next row, which brings it to 30 or more.
The recursive query output is:
lvl | ifrom | ito | val | csum | dsum |
---|---|---|---|---|---|
1 | 1 | 6 | 15 | 33 | 0 |
2 | 7 | 8 | 15 | 63 | 33 |
3 | 9 | 9 | 31 | 94 | 63 |
4 | 10 | 12 | 10 | 132 | 94 |
5 | 13 | 16 | 9 | 170 | 132 |
6 | 17 | 20 | 10 | 200 | 170 |
Then join with the source table and calculate the running total within each group as the row's overall running sum minus dsum.
-- explain analyze
with recursive t as(
select * ,sum(val)over(order by idx) csum,sum(val)over()totsum
,max(idx)over()maxidx
from ref
)
,r as(
(select 1 lvl,1 as iFrom,idx as iTo ,val, csum,cast(0 as bigint) as dsum
from t
where t.csum>=30 and (t.csum-val)<30
order by idx limit 1)
union all
select lvl+1,r.iTo+1 as iFrom,t.idx as iTo,t.val, t.csum,r.csum as dsum
from r inner join t on t.idx>r.iTo
and (
((t.csum-r.csum)>=30 and (t.csum-r.csum-t.val)<30)
or ((totsum-r.csum)<30 and t.idx=maxidx ) -- for the last (incomplete) group
)
)
select t.* -- ,r.*
,sum(t.val)over(order by idx)-dsum as grSum
,r.ifrom,r.ito
from ref t
inner join r on t.idx between r.iFrom and r.iTo
;
An index on idx is important if the data was added to the table in random order; without it, sorting or searching would take considerable time.
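The range-jumping idea behind the query above can also be sketched outside SQL, to show why the number of iterations equals the number of groups rather than the number of rows. This is a rough Python illustration, not a translation of the query: the function name reset_cumsum_by_groups is hypothetical, and the binary search assumes that all values are non-negative (so that the plain running sum never decreases), which holds for the sample data.

```python
from bisect import bisect_left
from itertools import accumulate

def reset_cumsum_by_groups(values, threshold=30):
    """Reset cumulative sum computed one group per iteration.

    Assumes all values are non-negative, so the plain running sum is
    sorted and can be searched with bisect.
    Returns (per-row group sums, number of iterations).
    """
    csum = list(accumulate(values))  # like sum(val) over (order by idx)
    out = []
    start = 0   # first row of the current group (0-based)
    dsum = 0    # plain running sum at the end of the previous group
    jumps = 0
    while start < len(values):
        # First row at or after start where csum - dsum reaches the threshold.
        end = bisect_left(csum, dsum + threshold, lo=start)
        # The last group may end before the threshold is reached.
        end = min(end, len(values) - 1)
        # Running total inside the group is csum - dsum.
        out.extend(c - dsum for c in csum[start:end + 1])
        dsum = csum[end]
        start = end + 1
        jumps += 1
    return out, jumps

# val column from the question
val = [1, 2, 5, 5, 5, 15, 15, 15, 31, 3, 25, 10, 29, 0, 0, 9, 4, 8, 8, 10]
group_sums, jumps = reset_cumsum_by_groups(val)
print(group_sums)
print(jumps)
```

On the sample data this takes 6 iterations for 20 rows, one per row of the recursive query output shown above.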
Upvotes: 0