Rorschach

Interactions between factors in data.table

How does one compute interactions using data.table? Specifically, I am trying to get all the unique combinations of successively larger groups of columns, accumulating from left to right (column 1, then columns 1:2, then columns 1:3, ...) and dropping unused levels. I am using code like this:

library(data.table)
## Sample data
set.seed(1999)
dat <- setDT(lapply(split(letters[1:9], 1:3), function(l) factor(sample(l, 10, TRUE, (1:3)^3))))
dat
#     1 2 3
#  1: d h i
#  2: g e f
#  3: g h i
#  4: g h i
#  5: d h i
#  6: g h c
#  7: d h i
#  8: g h f
#  9: g e i
# 10: d e i

## All factor combinations from left to right by column
f <- function(...) interaction(..., drop=TRUE)
levs <- Reduce(f, dat, accumulate = TRUE)
res <- unlist(lapply(levs, levels))
res
#  [1] "d"     "g"     "d.e"   "g.e"   "d.h"   "g.h"   "g.h.c" "g.e.f" "g.h.f"
# [10] "d.e.i" "g.e.i" "d.h.i" "g.h.i"

where res is the intended result. It works fine, but I might as well be using a data.frame, since this doesn't take advantage of any of data.table's internal matching.
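To make the target concrete, here is a small sketch on a hypothetical three-row table (`toy` is invented for illustration, not the OP's data) showing what `Reduce(f, ..., accumulate = TRUE)` produces: one factor per step, each the interaction of the columns seen so far. One detail worth noting: `Reduce` does not call `f` on the first element, so step 1 is the first column unchanged and `drop = TRUE` is never applied to it.

```r
library(data.table)

## Hypothetical toy table (not the OP's data) with three factor columns.
toy <- data.table(a = factor(c("x", "x", "y")),
                  b = factor(c("u", "v", "u")),
                  c = factor(c("p", "p", "q")))

f <- function(...) interaction(..., drop = TRUE)

## Reduce(accumulate = TRUE) returns one element per step:
## step 1 is column a itself, step 2 is interaction(a, b),
## step 3 is interaction(<step 2>, c).
steps <- Reduce(f, toy, accumulate = TRUE)

## drop = TRUE keeps only combinations that occur in the data, so the
## level sets are {x, y}, {x.u, x.v, y.u} and {x.u.p, x.v.p, y.u.q}.
lapply(steps, levels)
```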

This is even worse, because it repeats everything (it returns one value per row rather than one per unique combination):

dat[, Reduce(f, .SD, accumulate = TRUE)]

Can I replace base R's interaction with a faster data.table equivalent?

Edit

A larger example with data from ggplot2:

data(diamonds, package="ggplot2")
dat <- as.data.table(diamonds)
sdcols <- c("cut", "color", "clarity")  # some factor columns

## Expected output, really just interested in the levels,
## so character instead of factor is fine
levs <- unlist(Reduce(function(...) interaction(..., drop=TRUE),
                      dat[,sdcols,with=FALSE], accumulate = TRUE))
length(levels(levs))  # [1] 316

## Not quite right: this pastes every row, so the combinations are not unique
levs2 <- dat[, Reduce(function(...) do.call(paste, c(list(...), sep=".")), .SD,
                      accumulate = TRUE), .SDcols=sdcols]
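One possible repair of the attempt above, assuming (as the comment in the question says) that only the set of combination strings matters: the pasted vectors still hold one entry per row, so applying `unique()` to each step yields the combinations. The table `DT` and the helper `paste_dot` are hypothetical stand-ins, not part of the original code.

```r
library(data.table)

## Hypothetical small table standing in for dat[, sdcols, with = FALSE].
DT <- data.table(cut     = c("Good", "Good", "Ideal", "Good"),
                 color   = c("E", "E", "E", "J"),
                 clarity = c("SI1", "SI1", "VS2", "SI1"))

paste_dot <- function(...) paste(..., sep = ".")

## Pasting keeps one string per row, duplicates included ...
pasted <- Reduce(paste_dot, DT, accumulate = TRUE)

## ... so unique() on each step turns it into the set of combinations.
combos <- lapply(pasted, unique)
combos
```

The result is ordered by first appearance in the data rather than by interaction()'s level order, and pasting with "." can produce collisions if the values themselves contain dots.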

Answers (1)

Frank

Using the OP's example:

library(data.table)
data(diamonds, package="ggplot2")
dat <- as.data.table(diamonds)
sdcols <- c("cut", "color", "clarity")

DAT <- dat[, sdcols, with=FALSE]

Here are a few options:

f       <- function(...) interaction(..., drop=TRUE)
baseint <- function() lapply(Reduce(f, DAT, accumulate = TRUE), levels)

newint  <- function() lapply(seq_along(DAT), function(nj) do.call(paste, c(
  sep=".",
  unique(DAT[,seq(nj),with=FALSE])
)))

newint2  <- function(){
  DAT0 = unique(DAT)
  res  = vector("list", length(DAT))
  for (k in length(DAT):1){
    res[[k]] <- do.call(paste, c(sep=".",DAT0))
    DAT0[, (length(DAT0)) := NULL]
    DAT0 <- unique(DAT0)
  }
  res
}

library(microbenchmark)
microbenchmark(
  base = {baseres = baseint()},
  new  = {newres  = newint()},
  new2 = {newres2 = newint2()},
  times = 3
)

# Unit: milliseconds
#  expr       min        lq      mean    median        uq       max neval
#  base 14.110835 14.377433 16.910993 14.644031 18.311072 21.978113     3
#   new  3.335112  3.352311  3.680126  3.369511  3.852634  4.335756     3
#  new2  2.662375  2.843113  3.963925  3.023850  4.614700  6.205549     3

identical(lapply(baseres,sort), lapply(newres,sort))  # TRUE
identical(lapply(baseres,sort), lapply(newres2,sort)) # TRUE

The second idea for a new interaction, newint2, takes these steps:

  1. Uniquify data
  2. Paste columns
  3. Drop rightmost column
  4. Repeat from step 1 while any columns are left
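Why is it safe to deduplicate once and then keep shrinking the smaller table? Every distinct prefix combination first appears on a row that is itself the first occurrence of a full row, so `unique()` never discards a prefix combination. A quick check of that on a hypothetical toy table (`DT` is invented here), using data.table's `fsetequal()` to compare the results as sets:

```r
library(data.table)

## Hypothetical toy table; row 5 duplicates row 3.
DT <- data.table(a = c("x", "x", "y", "x", "y"),
                 b = c("u", "u", "v", "w", "v"),
                 c = c("p", "q", "p", "p", "p"))

full <- unique(DT)   # step 1: deduplicate whole rows

## Dropping rightmost columns of the deduplicated table and deduplicating
## again gives the same prefix combinations as using the original table.
ok_a  <- fsetequal(unique(full[, .(a)]),    unique(DT[, .(a)]))
ok_ab <- fsetequal(unique(full[, .(a, b)]), unique(DT[, .(a, b)]))
list(rows = c(nrow(DT), nrow(full)), same_prefixes = c(ok_a, ok_ab))
```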

Comments.

This is a very small example, and it's not clear to me what a larger one would look like (where we're talking about saving more than a few milliseconds).

One last option, if you only need the length of each accumulated interaction (the number of unique combinations):

dat <- as.data.table(diamonds)
## Sorting by the key makes equal combinations contiguous, so each run id is one combination
setkeyv(dat, sdcols)
tst <- vector("list", length(sdcols))
for (i in seq_along(sdcols)) tst[[i]] <- uniqueN(rleidv(dat[, sdcols[1:i], with=FALSE]))
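If only the counts are needed, data.table's `uniqueN()` also accepts a `by` argument that counts distinct combinations of the given columns directly. A sketch on a hypothetical stand-in table (`DT` and `cols` are invented here), checked against the `setkeyv()` + `rleidv()` route above:

```r
library(data.table)

## Hypothetical stand-in for dat[, sdcols, with = FALSE].
DT <- data.table(cut     = c("Good", "Good", "Ideal", "Good", "Ideal"),
                 color   = c("E", "E", "E", "J", "E"),
                 clarity = c("SI1", "VS2", "SI1", "SI1", "SI1"))
cols <- c("cut", "color", "clarity")

## Count distinct combinations of the first i columns, no key needed.
n_by <- sapply(seq_along(cols), function(i) uniqueN(DT, by = cols[seq_len(i)]))

## Same counts via the setkeyv() + rleidv() route, for comparison.
setkeyv(DT, cols)
n_rle <- sapply(seq_along(cols), function(i)
  uniqueN(rleidv(DT[, cols[seq_len(i)], with = FALSE])))
```

The rleidv() route only works because the key sort makes equal combinations contiguous; uniqueN(by = ) has no such requirement.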
