geher

Reputation: 495

Sampling with groups in pandas

I have a dataframe with user and session columns, and I want to randomly sample sessions so that the result contains N unique sessions per user. The ordering of rows within each session matters, i.e. the 'in' column per session must be preserved.

For example if N=2 and I have:

        x      in            session_id    user_id
0     0.0     1.0     trn-04a23351-283d       paul
1    -1.0     2.0     trn-04a23351-283d       paul
2    -1.0     3.0     trn-04a23351-283d       paul
3    -1.0     4.0     trn-04a23351-283d       paul
4    -1.0     1.0      blz-412313we-333       paul
5    -1.0     2.0      blz-412313we-333       paul
6     0.0     3.0      blz-412313we-333       paul
7    -1.0     1.0        wha-111111-fff       paul
8     0.0     2.0        wha-111111-fff       paul
9     1.0     1.0         bz-0000-01101      chris
10    0.0     2.0         bz-0000-01101      chris
11   -1.0     1.0       1111-sawas-1221      chris
12   -1.0     2.0       1111-sawas-1221      chris
13    1.0     1.0      pppppppppppppppp      chris
14    1.0     2.0      pppppppppppppppp      chris
15    1.0     3.0      pppppppppppppppp      chris
16   -1.0     1.0     55555555555555555     philip
17   -1.0     2.0     55555555555555555     philip
18   -1.0     3.0     55555555555555555     philip
19   -1.0     1.0       333333333333333     philip
20   -1.0     2.0       333333333333333     philip
21   -1.0     3.0       333333333333333     philip
22    0.0     1.0          zz-222222222     philip
23   -1.0     2.0          zz-222222222     philip
24    0.0     1.0       f-32355261-ss3d      sarah
25   -1.0     2.0       f-32355261-ss3d      sarah
26    0.0     3.0       f-32355261-ss3d      sarah
27    0.0     1.0               adasdfs      sarah
28   -1.0     2.0               adasdfs      sarah
29    0.0     3.0               adasdfs      sarah
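
For reproducibility, the input above can be rebuilt like this (a sketch that reads the printed table as CSV):

import io
import pandas as pd

df = pd.read_csv(io.StringIO("""\
x,in,session_id,user_id
0.0,1.0,trn-04a23351-283d,paul
-1.0,2.0,trn-04a23351-283d,paul
-1.0,3.0,trn-04a23351-283d,paul
-1.0,4.0,trn-04a23351-283d,paul
-1.0,1.0,blz-412313we-333,paul
-1.0,2.0,blz-412313we-333,paul
0.0,3.0,blz-412313we-333,paul
-1.0,1.0,wha-111111-fff,paul
0.0,2.0,wha-111111-fff,paul
1.0,1.0,bz-0000-01101,chris
0.0,2.0,bz-0000-01101,chris
-1.0,1.0,1111-sawas-1221,chris
-1.0,2.0,1111-sawas-1221,chris
1.0,1.0,pppppppppppppppp,chris
1.0,2.0,pppppppppppppppp,chris
1.0,3.0,pppppppppppppppp,chris
-1.0,1.0,55555555555555555,philip
-1.0,2.0,55555555555555555,philip
-1.0,3.0,55555555555555555,philip
-1.0,1.0,333333333333333,philip
-1.0,2.0,333333333333333,philip
-1.0,3.0,333333333333333,philip
0.0,1.0,zz-222222222,philip
-1.0,2.0,zz-222222222,philip
0.0,1.0,f-32355261-ss3d,sarah
-1.0,2.0,f-32355261-ss3d,sarah
0.0,3.0,f-32355261-ss3d,sarah
0.0,1.0,adasdfs,sarah
-1.0,2.0,adasdfs,sarah
0.0,3.0,adasdfs,sarah
"""))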

I want:

        x      in            session_id    user_id
0     0.0     1.0     trn-04a23351-283d       paul
1    -1.0     2.0     trn-04a23351-283d       paul
2    -1.0     3.0     trn-04a23351-283d       paul
3    -1.0     4.0     trn-04a23351-283d       paul
4    -1.0     1.0      blz-412313we-333       paul
5    -1.0     2.0      blz-412313we-333       paul
6     0.0     3.0      blz-412313we-333       paul
7     1.0     1.0         bz-0000-01101      chris
8     0.0     2.0         bz-0000-01101      chris
9     1.0     1.0      pppppppppppppppp      chris
10    1.0     2.0      pppppppppppppppp      chris
11    1.0     3.0      pppppppppppppppp      chris
12   -1.0     1.0       333333333333333     philip
13   -1.0     2.0       333333333333333     philip
14   -1.0     3.0       333333333333333     philip
15    0.0     1.0          zz-222222222     philip
16   -1.0     2.0          zz-222222222     philip
17    0.0     1.0       f-32355261-ss3d      sarah
18   -1.0     2.0       f-32355261-ss3d      sarah
19    0.0     3.0       f-32355261-ss3d      sarah
20    0.0     1.0               adasdfs      sarah
21   -1.0     2.0               adasdfs      sarah
22    0.0     3.0               adasdfs      sarah

Upvotes: 2

Views: 66

Answers (2)

piRSquared

Reputation: 294278

Create a reference dataframe of sampled (user, session) pairs, then merge it back with the original.

import pandas as pd

# unique (session, user) pairs, then sample 2 sessions per user
d = df[['session_id', 'user_id']].drop_duplicates()
d = d.groupby('user_id', as_index=False).apply(pd.DataFrame.sample, n=2)

# the inner merge keeps only rows from the sampled sessions, in their original order
df.merge(d)

      x   in        session_id user_id
0  -1.0  1.0  blz-412313we-333    paul
1  -1.0  2.0  blz-412313we-333    paul
2   0.0  3.0  blz-412313we-333    paul
3  -1.0  1.0    wha-111111-fff    paul
4   0.0  2.0    wha-111111-fff    paul
5   1.0  1.0     bz-0000-01101   chris
6   0.0  2.0     bz-0000-01101   chris
7  -1.0  1.0   1111-sawas-1221   chris
8  -1.0  2.0   1111-sawas-1221   chris
9  -1.0  1.0   333333333333333  philip
10 -1.0  2.0   333333333333333  philip
11 -1.0  3.0   333333333333333  philip
12  0.0  1.0      zz-222222222  philip
13 -1.0  2.0      zz-222222222  philip
14  0.0  1.0   f-32355261-ss3d   sarah
15 -1.0  2.0   f-32355261-ss3d   sarah
16  0.0  3.0   f-32355261-ss3d   sarah
17  0.0  1.0           adasdfs   sarah
18 -1.0  2.0           adasdfs   sarah
19  0.0  3.0           adasdfs   sarah
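
Note that sample(n=2) raises a ValueError for any user with fewer than 2 unique sessions; a minimal sketch that caps the sample size per group:

# variant: never ask for more sessions than a user actually has
d = df[['session_id', 'user_id']].drop_duplicates()
d = (d.groupby('user_id', as_index=False)
       .apply(lambda g: g.sample(n=min(2, len(g)))))

df.merge(d)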

Upvotes: 4

ALollz

Reputation: 59549

Use groupby + transform to build a boolean mask over the original dataframe, then subset the original df with that mask.

I used list(set(x)) (together with replace=False) to guarantee that the same session_id is not picked twice. This also means each session_id has an equal probability of being chosen, regardless of how many rows it had in the original df.

import pandas as pd
import numpy as np

np.random.seed(123)
# for each user, pick 2 distinct session_ids and flag every row belonging to them
mask = df.groupby('user_id').session_id.transform(
           lambda x: x.isin(np.random.choice(list(set(x)), 2, replace=False)))

df[mask]

Outputs:

      x   in         session_id user_id
0   0.0  1.0  trn-04a23351-283d    paul
1  -1.0  2.0  trn-04a23351-283d    paul
2  -1.0  3.0  trn-04a23351-283d    paul
3  -1.0  4.0  trn-04a23351-283d    paul
7  -1.0  1.0     wha-111111-fff    paul
8   0.0  2.0     wha-111111-fff    paul
11 -1.0  1.0    1111-sawas-1221   chris
12 -1.0  2.0    1111-sawas-1221   chris
13  1.0  1.0   pppppppppppppppp   chris
14  1.0  2.0   pppppppppppppppp   chris
15  1.0  3.0   pppppppppppppppp   chris
16 -1.0  1.0  55555555555555555  philip
17 -1.0  2.0  55555555555555555  philip
18 -1.0  3.0  55555555555555555  philip
22  0.0  1.0       zz-222222222  philip
23 -1.0  2.0       zz-222222222  philip
24  0.0  1.0    f-32355261-ss3d   sarah
25 -1.0  2.0    f-32355261-ss3d   sarah
26  0.0  3.0    f-32355261-ss3d   sarah
27  0.0  1.0            adasdfs   sarah
28 -1.0  2.0            adasdfs   sarah
29  0.0  3.0            adasdfs   sarah
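
To make N configurable and tolerate users with fewer than N sessions, the same mask idea can be wrapped in a small helper (sample_sessions is a hypothetical name, not from the answer above):

import numpy as np

def sample_sessions(frame, n, seed=None):
    # hypothetical helper: keep all rows of n randomly chosen sessions per user
    rng = np.random.RandomState(seed)
    mask = frame.groupby('user_id').session_id.transform(
        lambda x: x.isin(rng.choice(x.unique(), min(n, x.nunique()), replace=False)))
    return frame[mask]

# using the df from the question
sampled = sample_sessions(df, n=2, seed=123)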

Upvotes: 1
