Adrian

Reputation: 3308

Clean way to compute transition probabilities between two columns of a data.table in R

Toy example:

library(data.table)

set.seed(1)
n_people <- 100
groups <- c("A", "B", "C")
example_table <- data.table(person_id=seq_len(n_people),
                            group_2010=sample(groups, n_people, TRUE),
                            group_2011=sample(groups, n_people, TRUE))

## Error-prone and requires lots of typing -- programmatic alternative?
transition_probs <- example_table[, list(pr_A_2011=mean(group_2011=="A"),
                                         pr_B_2011=mean(group_2011=="B"),
                                         pr_C_2011=mean(group_2011=="C")),
                                         by=group_2010]
transition_probs  # Essentially a transition matrix giving Pr[group_2011 | group_2010]

#    group_2010 pr_A_2011 pr_B_2011 pr_C_2011
# 1:          A 0.1481481 0.5185185 0.3333333
# 2:          B 0.3684211 0.3947368 0.2368421
# 3:          C 0.3142857 0.3142857 0.3714286

The "manual" approach above is fine when the groups are A, B, C, but gets messy if there are more groups (or if we just have the groups vector but don't know ahead of time what it contains).

Is there a "data.table way" to compute the transition_probs object in my example code above? Can list(pr_A_2011=...) be replaced with something programmatic?

My concern is that, if I add a group D, I will have to edit the code in multiple places, notably by typing pr_D_2011=mean(group_2011=="D").

Upvotes: 3

Views: 111

Answers (3)

jangorecki

Reputation: 16697

Both current answers address your question very well, so I will answer it by handling the problem in a more generic way.
If you want real programmatic power, you can use R's computing on the language feature.

R belongs to a class of programming languages in which subroutines have the ability to modify or construct other subroutines and evaluate the result as an integral part of the language itself.

library(data.table)
set.seed(1)
n_people <- 100
groups <- c("A", "B", "C")
example_table <- data.table(person_id=seq_len(n_people),
                            group_2010=sample(groups, n_people, TRUE),
                            group_2011=sample(groups, n_people, TRUE))
f = function(data, groups, years) {
    stopifnot(is.data.table(data), length(groups) > 0L, length(years) == 2L, paste0("group_", years) %in% names(data))
    ## names of the output columns, e.g. pr_A_2011
    j.names = sprintf("pr_%s_%s", c(groups), years[2L])
    ## one quoted call mean(group_<year2> == <group>) per group
    j.vals = lapply(setNames(groups, j.names), function(group) call("mean", call("==", as.name(sprintf("group_%s", years[2L])), group)))
    ## combine them into a single quoted j expression: .(pr_A_2011 = mean(...), ...)
    jj = as.call(c(list(as.name(".")), j.vals))
    ## evaluate the constructed j, grouping by the first year's column
    data[, eval(jj), by = c(sprintf("group_%s", years[1L]))]
}
f(example_table, groups, 2010:2011)
#   group_2010 pr_A_2011 pr_B_2011 pr_C_2011
#1:          A 0.1481481 0.5185185 0.3333333
#2:          B 0.3684211 0.3947368 0.2368421
#3:          C 0.3142857 0.3142857 0.3714286

No need to edit the code in several places -- you just pass different arguments to the function.
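
For instance (a hypothetical extension, not part of the original answer), if a fourth group "D" appears later, only the groups argument changes:

f(example_table, c(groups, "D"), 2010:2011)
## adds a pr_D_2011 column -- all zeros here, since nobody is in group "D" in 2011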

Upvotes: 1

A. Webb

Reputation: 26446

data.table is intentionally designed to stay compatible with data.frame operations, so unless you can (a) prove this operation is a huge bottleneck and (b) demonstrate that alternative solutions are significantly faster, why not stick with brevity and clarity:

prop.table(table(example_table[,2:3,with=FALSE]),1)
          group_2011
group_2010         A         B         C
         A 0.1481481 0.5185185 0.3333333
         B 0.3684211 0.3947368 0.2368421
         C 0.3142857 0.3142857 0.3714286
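
If the positional selection [,2:3] feels fragile, the same one-liner can be written with column names built programmatically from the years (a minor variation, assuming the group_<year> naming used above):

cols <- paste0("group_", c(2010, 2011))  # build the column names programmatically
prop.table(table(example_table[, cols, with=FALSE]), 1)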

Upvotes: 2

Frank

Reputation: 66819

I would do

## all group labels observed in either year, so new groups are picked up automatically
lvls = example_table[, sort(unique(c(group_2010, group_2011))) ]
## cross-tabulate counts of 2010 group vs 2011 group, then add a row-total column N
x = dcast(example_table, group_2010~group_2011)[, N := Reduce(`+`,.SD), .SDcols=lvls]

#    group_2010  A  B  C  N
# 1:          A  4 14  9 27
# 2:          B 14 15  9 38
# 3:          C 11 11 13 35

From here, if you want transition probabilities, just divide by N:

x[, (lvls) := lapply(.SD,`/`, x$N), .SDcols=lvls]
# or, with data.table 1.9.7+
x[, (lvls) := lapply(.SD,`/`, N), .SDcols=lvls]

#    group_2010         A         B         C  N
# 1:          A 0.1481481 0.5185185 0.3333333 27
# 2:          B 0.3684211 0.3947368 0.2368421 38
# 3:          C 0.3142857 0.3142857 0.3714286 35
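
If you also want the pr_*_2011 column names from the question, one small follow-up (a hypothetical addition, not part of the original answer) is to rename the level columns:

x[, N := NULL]                                   # drop the row totals if no longer needed
setnames(x, lvls, paste0("pr_", lvls, "_2011"))  # e.g. A -> pr_A_2011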

Upvotes: 3
