Reputation: 21705
I came across this issue today. I have a data.table with some categorical fields (i.e. factors), something like:
library(data.table)
set.seed(2016)
dt <- data.table(
  ID=factor(sample(30000, 2000000, replace=TRUE)),
  Letter=factor(LETTERS[sample(26, 2000000, replace=TRUE)])
)
dt
ID Letter
1: 5405 E
2: 4289 E
3: 25250 J
4: 4008 J
5: 14326 G
---
Now I'd like to calculate the Gini impurity of each column of dt, grouped by the values in ID.
My attempt:
giniImpurity <- function(vals){
  # Returns the gini impurity of a set of categorical values
  # vals can either be the raw category instances (vals=c("red", "red", "blue", "green"))
  # or named category frequencies (vals=c(red=2, blue=1, green=1))
  # Gini Impurity is the probability a value is incorrectly labeled when labeled
  # according to the distribution of classes in the set
  if(is(vals, "numeric")) counts <- vals else counts <- table(vals)
  total <- sum(counts)
  return(sum((counts/total)*(1-counts/total)))
}
# Calculate gini impurities
dt[, list(Samples=.N, ID.GiniImpurity=giniImpurity(ID), Letter.GiniImpurity=giniImpurity(Letter)), by=ID]
ID Samples ID.GiniImpurity Letter.GiniImpurity
1: 5405 66 0 0.9527
2: 4289 73 0 0.9484
3: 25250 60 0 0.9394
4: 4008 66 0 0.9431
5: 14326 79 0 0.9531
---
This works, but it's incredibly slow. If I change ID from factor to numeric, it runs much quicker. Should I do that in practice, or is there a less hacky way to speed up this operation?
Also, I know it's unnecessary to calculate the gini impurity of ID grouped by itself, but please look past this. My real dataset has many more categorical features which add to the slowness.
Also note that I'm using data.table version 1.9.7 (devel)
Sorry guys... I just realized that when I tested this with ID as numeric instead of a factor, the speedup came from inside my call to giniImpurity(): a numeric vector is treated as precomputed counts, so table() is never called. I guess the call to table() is where the slowdown is. Still not 100% sure how to make this quicker though.
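One likely contributor to the cost of table() here: given a factor, table() produces a cell for every level, including levels that do not occur in the subset. With ID grouped by ID, each per-group call would therefore still tabulate all of the factor's levels. The following is a minimal base-R sketch on a toy factor (the values and the helper name gini are illustrative assumptions, not taken from the dataset above):

```r
# Toy factor with 5 levels; values are illustrative only
f <- factor(c("a", "b", "c", "d", "e", "a"))
grp <- f[f == "a"]   # a "group" holding one distinct value; all 5 levels are kept

full_counts <- table(grp)              # one cell per level of f, mostly zeros
used_counts <- table(droplevels(grp))  # only the levels actually present

length(full_counts)
length(used_counts)

# Zero-count cells add nothing to the impurity, so both versions agree
gini <- function(counts) {
  p <- counts / sum(counts)
  sum(p * (1 - p))
}
gini(full_counts) == gini(used_counts)
```

Since the empty cells do not change the result, counting only the combinations that occur (for example with .N grouped by the relevant columns) should give the same impurity with far less work per group; how much faster it is on the real data is untested here.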
Upvotes: 2
Views: 92
Reputation: 21705
Got it.
giniImpurities <- function(dt){
  # Returns pairs of categorical fields (cat1, cat2, GI) where GI is the weighted gini impurity of
  # cat2 relative to the groups determined by cat1
  #--------------------------------------------------
  # Subset dt by just the categorical fields
  catfields <- colnames(dt)[sapply(dt, is.factor)]
  cats1 <- dt[, catfields, with=FALSE]
  # Build a table to store the results
  varpairs <- CJ(Var1=catfields, Var2=catfields)
  varpairs[Var1==Var2, GI := 0]
  # Loop through each grouping variable
  for(catcol in catfields){
    print(paste("Calculating gini impurities by field:", catcol))
    setkeyv(cats1, catcol)
    impuritiesDT <- cats1[, list(Samples=.N), keyby=catcol]
    # Loop through each of the other categorical columns
    for(colname in setdiff(catfields, catcol)){
      # Get the gini impurity for each pair (catcol, other)
      counts <- cats1[, list(.N), by=c(catcol, colname)]
      impurities <- counts[, list(GI=sum((N/sum(N))*(1-N/sum(N)))), by=catcol]
      impuritiesDT[impurities, GI := GI]
      setnames(impuritiesDT, "GI", colname)
    }
    # Average the per-group impurities, weighted by group size
    cats1.gini <- melt(impuritiesDT, id.vars=c(catcol, "Samples"))
    cats1.gini <- cats1.gini[, list(GI=weighted.mean(x=value, w=Samples)), by=variable]
    cats1.gini <- cats1.gini[, list(Var1=catcol, Var2=variable, GI)]
    varpairs[cats1.gini, `:=`(GI=i.GI), on=c("Var1", "Var2")]
  }
  return(varpairs[])
}
giniImpurities(dt)
Var1 Var2 GI
1: Letter Letter 0.0000000
2: Letter Letter2 0.9615258
3: Letter PGroup 0.9999537
4: Letter2 Letter 0.9615254
5: Letter2 Letter2 0.0000000
6: Letter2 PGroup 0.9999537
7: PGroup Letter 0.9471393
8: PGroup Letter2 0.9470965
9: PGroup PGroup 0.0000000
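To make the "weighted gini impurity" reported in the GI column concrete, here is a small hand-checkable sketch in base R. It follows the same idea as the function above (per-group impurity, then an average weighted by group size), but the toy vectors g and v and the helper gini are illustrative assumptions, not part of the code or data above:

```r
# Toy data: g is the grouping field, v is the field whose impurity is measured
g <- c("x", "x", "x", "x", "y", "y")
v <- c("a", "a", "b", "b", "a", "a")

# Gini impurity of a vector of raw category values
gini <- function(vals) {
  p <- table(vals) / length(vals)
  sum(p * (1 - p))
}

per_group <- tapply(v, g, gini)    # impurity of v within each group of g
sizes     <- tapply(v, g, length)  # number of rows in each group

# Group "x" is split 2/2, so its impurity is 0.5; group "y" is pure, so 0.
# The weighted average is (4 * 0.5 + 2 * 0) / 6 = 1/3.
weighted_gi <- weighted.mean(per_group, sizes)
weighted_gi
```

A value near 0 means the grouping field almost determines the other field; values close to 1, as in most rows of the output above, mean the groups are nearly as mixed as they can be.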
Upvotes: 1