Ben

Reputation: 21705

data.table slow aggregating on factor column

Came across this issue today. I have a data.table with some categorical fields (i.e. factors). Something like

library(data.table)

set.seed(2016)
dt <- data.table(
  ID=factor(sample(30000, 2000000, replace=TRUE)), 
  Letter=factor(LETTERS[sample(26, 2000000, replace=TRUE)])
)

dt
      ID Letter
1:  5405      E
2:  4289      E
3: 25250      J
4:  4008      J
5: 14326      G
---

Now, I'd like to calculate the Gini impurity of each column of dt, grouped by the values in ID.

My attempt:

giniImpurity <- function(vals){
  # Returns the Gini impurity of a set of categorical values.
  # vals can be either the raw category instances, e.g. vals=c("red", "red", "blue", "green"),
  # or named category frequencies, e.g. vals=c(red=2, blue=1, green=1).
  # Gini impurity is the probability that a value is incorrectly labeled when it is
  # labeled at random according to the distribution of classes in the set.

  # Numeric input is treated as frequencies; anything else is tabulated
  if(is.numeric(vals)) counts <- vals else counts <- table(vals)
  total <- sum(counts)

  return(sum((counts/total)*(1-counts/total)))
}

# Calculate gini impurities
dt[, list(Samples=.N, ID.GiniImpurity=giniImpurity(ID), Letter.GiniImpurity=giniImpurity(Letter)), by=ID]
          ID Samples ID.GiniImpurity Letter.GiniImpurity
    1:  5405      66               0              0.9527
    2:  4289      73               0              0.9484
    3: 25250      60               0              0.9394
    4:  4008      66               0              0.9431
    5: 14326      79               0              0.9531
   ---
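
To make the two input forms concrete, here is a small worked example using the colors from the function's comment. It is only a sketch: the function is restated in compact form so the snippet runs on its own, and the variable names raw, freq, fromRaw, fromFreq and asCounts are made up for illustration.

```r
# Same logic as giniImpurity() above, restated so this snippet is self-contained
giniImpurity <- function(vals) {
  counts <- if (is.numeric(vals)) vals else table(vals)
  total <- sum(counts)
  sum((counts / total) * (1 - counts / total))
}

raw  <- c("red", "red", "blue", "green")   # raw category instances
freq <- c(red = 2, blue = 1, green = 1)    # the same data as frequencies

# Proportions are 0.5, 0.25, 0.25, so both calls give
# 0.5*0.5 + 0.25*0.75 + 0.25*0.75 = 0.625
fromRaw  <- giniImpurity(raw)
fromFreq <- giniImpurity(freq)

# Caveat: numeric input is always read as frequencies, so numeric category
# codes are NOT tabulated. Here the counts are taken to be 1, 1, 2, 3.
asCounts <- giniImpurity(c(1, 1, 2, 3))
```

Because of that caveat, passing a numeric ID column gives a different quantity than passing the same column as a factor.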

This works but it's incredibly slow. It seems that if I change ID from factor to numeric, it runs much quicker. Is this what I should do in practice or is there a less hacky way to speed up this operation?

Also, I know it's unnecessary to calculate the gini impurity of ID grouped by itself, but please look past this. My real dataset has many more categorical features which add to the slowness.

Also note that I'm using data.table version 1.9.7 (devel)


EDIT

Sorry guys... I just realized that when I tested this with ID as numeric instead of a factor, the speedup came from my call to giniImpurity() itself: numeric input is treated as category frequencies, so table() is skipped entirely (and the result is no longer the impurity of the raw values). I guess the call to table() is where the slowdown is. Still not 100% sure how to make this quicker though.
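
One likely reason table() hurts here is that, on a factor, it returns a count for every level, including levels that do not occur in the current group. The sketch below shows this in base R on a smaller dataset than the one in the question, and then tries a counting step that only looks at observed values. The helper giniImpurityObserved() is hypothetical (it is not from the post), and I have not benchmarked it against the original.

```r
set.seed(2016)

# Smaller than the question's data so the sketch runs quickly
ids <- factor(sample(3000, 200000, replace = TRUE))

# The values one group would receive when grouping by the factor itself
oneGroup <- ids[ids == levels(ids)[1]]

# table() keeps every level of the factor, even though only one is present
nCells    <- length(table(oneGroup))   # equals nlevels(ids)
nObserved <- sum(table(oneGroup) > 0)  # only one level actually occurs

# Hypothetical helper: count only the values that occur in vals
giniImpurityObserved <- function(vals) {
  counts <- tabulate(match(vals, unique(vals)))
  p <- counts / sum(counts)
  sum(p * (1 - p))
}

# Check against the table()-based calculation on a small factor
lettersSample <- factor(LETTERS[sample(26, 1000, replace = TRUE)])
p <- as.vector(table(lettersSample)) / length(lettersSample)
viaTable    <- sum(p * (1 - p))
viaObserved <- giniImpurityObserved(lettersSample)
```

Empty levels contribute 0 to the sum, so both versions should agree; the difference is only in how much work is done per group.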

Upvotes: 2

Views: 92

Answers (1)

Ben

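Got it. Rather than calling table() inside every group, I count each (group, category) pair with .N and compute the impurities from those counts.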

giniImpurities <- function(dt){
  # Returns pairs of categorical fields (cat1, cat2, GI) where GI is the weighted gini impurity of 
  # cat2 relative to the groups determined by cat1

  #--------------------------------------------------
  # Subset dt by just the categorical fields

  catfields <- colnames(dt)[sapply(dt, is.factor)]
  cats1 <- dt[, catfields, with=FALSE]

  # Build a table to store the results
  varpairs <- CJ(Var1=catfields, Var2=catfields)
  varpairs[Var1==Var2, GI := 0]

  # Loop through each grouping variable
  for(catcol in catfields){
    print(paste("Calculating gini impurities by field:", catcol))

    setkeyv(cats1, catcol)
    impuritiesDT <- cats1[, list(Samples=.N), keyby=catcol]

    # Loop through each of the other categorical columns
    for(colname in setdiff(catfields, catcol)){

      # Get the gini impurity for each pair (catcol, other)
      counts <- cats1[, list(.N), by=c(catcol, colname)]
      impurities <- counts[, list(GI=sum((N/sum(N))*(1-N/sum(N)))), by=catcol]
      impuritiesDT[impurities, GI := i.GI]
      setnames(impuritiesDT, "GI", colname)
    }

    cats1.gini <- melt(impuritiesDT, id.vars=c(catcol, "Samples"))
    cats1.gini <- cats1.gini[, list(GI=weighted.mean(x=value, w=Samples)), by=variable]
    cats1.gini <- cats1.gini[, list(Var1=catcol, Var2=variable, GI)]
    varpairs[cats1.gini, `:=`(GI=i.GI), on=c("Var1", "Var2")]
  }

  return(varpairs[])
}

giniImpurities(dt)
      Var1    Var2        GI
1:  Letter  Letter 0.0000000
2:  Letter Letter2 0.9615258
3:  Letter  PGroup 0.9999537
4: Letter2  Letter 0.9615254
5: Letter2 Letter2 0.0000000
6: Letter2  PGroup 0.9999537
7:  PGroup  Letter 0.9471393
8:  PGroup Letter2 0.9470965
9:  PGroup  PGroup 0.0000000
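
The counting step inside the loop can be tried in isolation on a toy table. This is a minimal sketch, assuming a made-up data.table called toy with columns Group and Color (neither is from the post); it mirrors the count, per-group impurity and weighted-mean steps of the function for a single pair of columns.

```r
library(data.table)

# Toy data (hypothetical): two groups and one categorical column
toy <- data.table(
  Group = factor(c("a", "a", "a", "a", "b", "b")),
  Color = factor(c("red", "red", "blue", "green", "red", "red"))
)

# Step 1: count only the (Group, Color) pairs that actually occur
counts <- toy[, .N, by = .(Group, Color)]

# Step 2: Gini impurity of Color within each Group, from those counts
perGroup <- counts[, .(Samples = sum(N),
                       GI = sum((N / sum(N)) * (1 - N / sum(N)))),
                   by = Group]

# Step 3: average the per-group impurities, weighted by group size
weightedGI <- perGroup[, weighted.mean(GI, w = Samples)]
```

Group a has proportions 0.5, 0.25, 0.25 (impurity 0.625) and group b is pure (impurity 0), so the weighted result is (4 * 0.625 + 2 * 0) / 6.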

Upvotes: 1
