jwdink

Reputation: 4875

Pandas: Group-by and Aggregate Column 1 with Condition from Column 2

I'm trying to move from R & dplyr to Python and pandas for some projects, and I'm hoping to figure out how to replicate common coding strategies I used with dplyr.

A common pattern is to group by a particular column, then calculate a derived column that involves a condition on another column. Here's a simple example:

dat = data.frame(user = rep(c("1",2,3,4),each=5),
           cancel_date = rep(c(12,5,10,11), each=5)
) %>%
  group_by(user) %>%
  mutate(login = sample(1:cancel_date[1], size = n(), replace = T)) %>%
  ungroup()

Source: local data frame [6 x 3]

  user cancel_date login
1    1          12     3
2    1          12     9
3    1          12    12
4    1          12     4
5    1          12     2
6    2           5     4

In this data frame, I'd like to count how many logins each user had at least three months before they cancelled. In dplyr, this is simple:

dat %>%
  group_by(user) %>%
  summarise(logins_three_mos_before_cancel = length(login[cancel_date-login>=3]))

  user logins_three_mos_before_cancel
1    1                              4
2    2                              1
3    3                              5
4    4                              3

But I'm a bit stumped about how to do this in pandas. As far as I can tell, aggregate only applies a function to a single grouped column, and I don't know how to get it to apply a function that involves multiple columns.

Here's that same data in pandas:

import numpy as np
import pandas as pd

d = {'user': np.repeat([1, 2, 3, 4], 5),
     'cancel_date': np.repeat([12, 5, 10, 11], 5),
     'login': np.array([3, 9, 12, 4, 2, 4, 3, 5, 5, 1, 3, 5, 4, 6, 3, 3, 5, 10, 7, 10])
     }
df = pd.DataFrame(data=d)

Upvotes: 0

Views: 600

Answers (2)

Panwen Wang

Reputation: 3835

It's pretty easy to translate your R code into Python with datar:

>>> from datar.all import (
...     f, c, tibble, rep, length, set_seed,
...     group_by, mutate, sample, n, ungroup, summarise, 
... )
>>> 
>>> set_seed(8525)
>>> 
>>> dat = tibble(
...     user=rep(c("1", 2, 3, 4), each=5),
...     cancel_date=rep(c(12, 5, 10, 11), each=5)
... ) >> group_by(
...     f.user
... ) >> mutate(
...     login=sample(f[1:f.cancel_date[0]], size=n(), replace=True)
... ) >> ungroup()
>>> 
>>> dat
       user  cancel_date   login
   <object>      <int64> <int64>
0         1           12       6
1         1           12      11
2         1           12       6
3         1           12       1
4         1           12       7
5         2            5       4
6         2            5       2
7         2            5       4
8         2            5       4
9         2            5       1
10        3           10       5
11        3           10       2
12        3           10       9
13        3           10      10
14        3           10       3
15        4           11      11
16        4           11       6
17        4           11      10
18        4           11       1
19        4           11       6
>>> dat >> group_by(
...     f.user
... ) >> summarise(
...     logins_three_mos_before_cancel = length(f.login[f.cancel_date-f.login>=3])
... )
      user  logins_three_mos_before_cancel
  <object>                         <int64>
0        1                               4
1        2                               2
2        3                               3
3        4                               3

Disclaimer: I am the author of the datar package.

Upvotes: 0

Ami Tavory

Reputation: 76346

I hope I followed your R, but do you mean this?

>>> df[df.cancel_date - df.login >= 3].user.value_counts().sort_index()
1    4
2    1
3    5
4    3
dtype: int64
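
One caveat with filtering before counting: a user with no qualifying logins drops out of the result entirely. If you want something closer to the dplyr `summarise`, here is a minimal sketch that builds the condition as a boolean Series and sums it per user. It assumes the question's data is in a frame named `df`; the variable names `early` and `result` are mine, and the output name just reuses the one from the R code.

```python
import numpy as np
import pandas as pd

# Same data as in the question.
df = pd.DataFrame({
    "user": np.repeat([1, 2, 3, 4], 5),
    "cancel_date": np.repeat([12, 5, 10, 11], 5),
    "login": [3, 9, 12, 4, 2, 4, 3, 5, 5, 1, 3, 5, 4, 6, 3, 3, 5, 10, 7, 10],
})

# True where the login happened at least 3 months before the cancel date.
early = (df["cancel_date"] - df["login"]) >= 3

# Summing booleans within each user counts the True rows;
# a user with no True rows gets 0 instead of disappearing.
result = early.groupby(df["user"]).sum().rename("logins_three_mos_before_cancel")
print(result)
```

Because the condition is computed on the whole frame before grouping, it can involve any number of columns, which sidesteps the "aggregate only sees one column" problem.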

Upvotes: 2