ohwaitwhatohok

Reputation: 11

Recursive SQL attempt going into infinite loop

I'm trying to write a recursive SQL query (DuckDB) to calculate a cumulative sum with a reset condition. The reset condition in the example provided below is that the running total has reached 30 or more; the next row then starts a new sum from its own val.

I have done this successfully (see Code 1, provided below).

However, I would also like to solve this problem more efficiently (in DuckDB), as the code takes a while to run on thousands of records. Perhaps Code 1 is already (near) optimal and I'm wasting my time?

Code 2 is my attempt at what I believe to be a faster approach; however, it doesn't work (I think it runs into an infinite loop).

The reason I think it would be a faster approach (if it worked) is that it attempts to:

  1. filter the reference dataset to exclude the rows that have already been unioned into the cte at each iteration.
  2. then, on these remaining rows, calculate a cumulative sum and keep only the records up to the point where the sum exceeds the threshold.
  3. union all into the cte and repeat from step 1 above.

I would expect that, in the worst case, this would have the same run time as Code 1 (if it needed to iterate once per row, as Code 1 does). In all other cases the run time would be faster, roughly proportional to the number of iterations needed to satisfy the threshold condition.
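
To make the intended iteration concrete, here is a minimal non-recursive sketch of those three steps in plain R (an illustration of the semantics only, not a proposed solution; it assumes the con connection and ref table created in the setup code below, and the names remaining, out and grp are arbitrary):

# loop until every row of ref has been assigned to a group
remaining <- dbGetQuery(con, "SELECT idx, val FROM ref ORDER BY idx")
out <- list()
grp <- 1
while (nrow(remaining) > 0) {
  cs <- cumsum(remaining$val)
  # keep rows up to and including the first one whose running sum reaches the threshold
  k <- which(cs >= 30)[1]
  if (is.na(k)) k <- nrow(remaining)            # last (incomplete) group
  out[[grp]] <- data.frame(remaining[1:k, ], c_val = cs[1:k])
  remaining <- remaining[-(1:k), , drop = FALSE] # exclude rows already assigned
  grp <- grp + 1
}
result <- do.call(rbind, out)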

It isn't obvious to me what is causing the issue.

My debugging attempts focused on isolating various parts of Code 2. Each part seemed to produce what one would expect in isolation, but the query as a whole does not work as intended.

Any help is appreciated.

library(data.table)
library(duckdb)

ref <- data.table(
  idx  = c(1:20)
  ,val = c(1,2,5,5,5,15,15,15,31,3,25,10,29,0,0,9,4,8,8,10)
)

con <- dbConnect(duckdb())

dbWriteTable(con, "ref", ref, overwrite = TRUE)


### Code 1

code_1 <- dbGetQuery(con, paste0("
    WITH 
      RECURSIVE cte as (
                          SELECT
                            idx
                            ,val
                            ,val AS c_val
                          FROM
                            ref
                          WHERE
                            idx = 1
    
                          UNION ALL
    
                          SELECT
                            ref.idx
                            ,ref.val
                            ,(CASE WHEN ref.val + cte.c_val >= 30 + ref.val THEN ref.val
                                                                            ELSE ref.val + cte.c_val
                              END) AS c_val
                          FROM
                            cte
                            JOIN
                            ref ON ref.idx = cte.idx + 1
                       )
      SELECT * FROM cte
 "))

### Code 2

code_2 <- dbGetQuery(con, paste0("
    WITH 
      RECURSIVE cte as (
        SELECT
           0 AS idx
          ,0 AS val
          ,0 AS c_val

        UNION ALL

        SELECT
           idx
          ,val
          ,c_val
        FROM
          (
            SELECT
              *
            FROM
              (
                SELECT
                  * 
                  ,SUM(val) OVER (ORDER BY idx) AS c_val
                  ,c_val-val                    AS c_val_adj
                FROM
                  (
                    SELECT
                       idx
                      ,val
                    FROM
                      ref
                    WHERE
                      idx NOT IN (SELECT idx FROM cte)
                  )
              )
            WHERE
              c_val_adj <= 30
          )
     )
                       
  SELECT * FROM cte
 "))

For reference, here is what Code 1 generates (and what I want Code 2 to generate as well):

idx  val  c_val
  1    1      1
  2    2      3
  3    5      8
  4    5     13
  5    5     18
  6   15     33
  7   15     15
  8   15     30
  9   31     31
 10    3      3
 11   25     28
 12   10     38
 13   29     29
 14    0     29
 15    0     29
 16    9     38
 17    4      4
 18    8     12
 19    8     20
 20   10     30

Upvotes: 1

Views: 167

Answers (2)

G. Grothendieck

Reputation: 270045

1) Below we get a 3x speedup on the data shown in the question by reading the SQL table into R, performing the accumulation in R and then registering the result back into the database. The input is shown in the Note at the end. Run that first.

library(data.table)
library(duckdb)
library(microbenchmark)

con <- dbConnect(duckdb())
duckdb_register(con, "ref", ref, overwrite = TRUE)

microbenchmark(times = 10,
SQL = code_1 <- dbGetQuery(con, "
  WITH RECURSIVE cte as (
    SELECT idx, val, val AS c_val FROM ref WHERE idx = 1
    UNION ALL
    SELECT 
      ref.idx, 
      ref.val,
      (CASE WHEN ref.val + cte.c_val >= 30 + ref.val
       THEN ref.val 
       ELSE ref.val + cte.c_val END) AS c_val
    FROM cte
    JOIN ref ON ref.idx = cte.idx + 1
  )
  SELECT * FROM cte
"),
R = {
  # read the table into R, accumulate with a reset once the running total
  # reaches 30, and register the result back into DuckDB
  refDF <- dbGetQuery(con, "SELECT * FROM ref")
  f3 <- function(x, y) if (x >= 30) y else x + y  # x = previous running total, y = current val
  refDF$cum <- Reduce(f3, refDF$val, init = 0, acc = TRUE)[-1]
  duckdb_register(con, "refDF", refDF, overwrite = TRUE)
})

giving

## Unit: milliseconds
##  expr     min      lq     mean   median      uq     max neval cld
##   SQL 14.9846 15.2263 25.40338 16.52835 22.7394 91.8009    10  a 
##     R  5.7175  5.7452  8.60860  5.82635  5.9205 33.7428    10   b

identical(refDF$cum, expected)
## [1] TRUE

2) This alternative uses the DuckDB CLI and gawk rather than R. Ensure you have both installed. On Windows, RTools ships with gawk; on Linux, gawk is normally included.

Run the code in the Note at the end in R, and then run the following in R to create the ref.duckdb database containing the input table ref. This only sets up the test framework and is not part of the actual processing.

library(duckdb)
con <- dbConnect(duckdb(), "ref.duckdb")
dbWriteTable(con, "ref", ref)
dbDisconnect(con)

Now define this gawk file called ref.awk using a text editor:

BEGIN { print "idx,val,cum" }                                             # write the output header
NR == 1 { next }                                                          # skip the input header line
{ if (prev >= 30) prev = $2; else prev = $2 + prev; print $0 "," prev }   # reset once the previous running total reaches 30

From the system command line (not from R) run the following. (If you are on Windows and gawk is not on your PATH, you may need to refer to it by its full path, e.g. \Rtools44\usr\bin\gawk.) The first line copies the ref table in ref.duckdb to a csv file, the second line does the actual processing with gawk, creating ref2.csv, and the last line uploads ref2.csv back into the ref.duckdb database. Note that the -k option tells gawk that the input is a csv file; it is available in gawk but may not be in other versions of awk.

duckdb ref.duckdb "copy ref to 'ref.csv'"
gawk -k -f ref.awk ref.csv > ref2.csv
duckdb ref.duckdb "create table ref2 as from read_csv_auto('ref2.csv')"

Note

library(data.table)

ref <- data.table(
  idx  = 1:20, 
  val = c(1,2,5,5,5,15,15,15,31,3,25,10,29,0,0,9,4,8,8,10)
)

expected <- c(1, 3, 8, 13, 18, 33, 15, 30, 31, 3, 28, 38, 29, 29, 29, 38, 
  4, 12, 20, 30)

Upvotes: 2

ValNik

Reputation: 5916

See another example.
This one is somewhat closer to your Code 2. The more rows that fall into a single group, the more this query pays off:
its recursion depth is (row count) / (average group size),
whereas in your Code 1 the depth is (row count). For the 20-row sample data that means 6 iterations instead of 20.
The recursive query works on row ranges: each range covers the rows whose cumulative sum within the range is still below 30, plus the next row (the one that pushes it to 30 or above).
The recursive query's output is

lvl  ifrom  ito  val  csum  dsum
  1      1    6   15    33     0
  2      7    8   15    63    33
  3      9    9   31    94    63
  4     10   12   10   132    94
  5     13   16    9   170   132
  6     17   20   10   200   170

This result is joined back to the source table, and the running total for every group is calculated as the overall running sum minus dsum.

-- explain analyze
with recursive t as (
  select *
    ,sum(val) over (order by idx) as csum
    ,sum(val) over ()             as totsum
    ,max(idx) over ()             as maxidx
  from ref
)
,r as (
   (select 1 as lvl, 1 as iFrom, idx as iTo, val, csum, cast(0 as bigint) as dsum
    from t
    where t.csum >= 30 and (t.csum - val) < 30
    order by idx limit 1)
  union all
   select lvl + 1, r.iTo + 1 as iFrom, t.idx as iTo, t.val, t.csum, r.csum as dsum
   from r inner join t on t.idx > r.iTo
       and (
          ((t.csum - r.csum) >= 30 and (t.csum - r.csum - t.val) < 30)
       or ((totsum - r.csum) < 30 and t.idx = maxidx)  -- for the last (incomplete) group
      )
)
select t.*  -- ,r.*
  ,sum(t.val) over (order by idx) - dsum as grSum
  ,r.iFrom, r.iTo
from ref t
inner join r on t.idx between r.iFrom and r.iTo
;

Test fiddle (PostgreSQL)

An index on (idx) is important if the data was added to the table in random order; without it, the sorting and searching would take considerable time.
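
In DuckDB from R, creating such an index could look like this (a one-line sketch, assuming the ref table written with dbWriteTable as in the question; the index name ref_idx is just illustrative):

dbExecute(con, "CREATE INDEX ref_idx ON ref (idx)")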

Upvotes: 0
