Kenan

Reputation: 14094

Pandas groupby and transform based on multiple columns

I have seen a lot of similar questions, but none seem to work for my case. I'm pretty sure this is just a groupby transform, but I keep getting a KeyError along with axis issues. I am trying to group by filename and count the rows where pred != gt.

For example, index 2 is the only mismatch for f1.wav, so its count is 1, and indices (13, 14, 18) mismatch for f2.wav, so its count is 3.

df = pd.DataFrame([{'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 2, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f2.wav'}, {'pred': 2, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f2.wav'}, {'pred': 2, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 2, 'gt': 0, 'filename': 'f2.wav'}])
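    pred  gt filename
0      0   0   f1.wav
1      0   0   f1.wav
2      2   0   f1.wav
3      0   0   f1.wav
4      0   0   f1.wav
5      0   0   f1.wav
6      0   0   f1.wav
7      0   0   f1.wav
8      0   0   f1.wav
9      0   0   f1.wav
10     0   0   f2.wav
11     0   0   f2.wav
12     2   2   f2.wav
13     0   2   f2.wav
14     0   2   f2.wav
15     0   0   f2.wav
16     0   0   f2.wav
17     2   2   f2.wav
18     0   2   f2.wav
19     2   0   f2.wav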

Expected output

    pred  gt filename  counts
0      0   0   f1.wav       1
1      0   0   f1.wav       1
2      2   0   f1.wav       1
3      0   0   f1.wav       1
4      0   0   f1.wav       1
5      0   0   f1.wav       1
6      0   0   f1.wav       1
7      0   0   f1.wav       1
8      0   0   f1.wav       1
9      0   0   f1.wav       1
10     0   0   f2.wav       3
11     0   0   f2.wav       3
12     2   2   f2.wav       3
13     0   2   f2.wav       3
14     0   2   f2.wav       3
15     0   0   f2.wav       3
16     0   0   f2.wav       3
17     2   2   f2.wav       3
18     0   2   f2.wav       3
19     2   0   f2.wav       3

I tried df.groupby('filename').transform(lambda x: x['pred'].ne(x['gt']).sum(), axis=1), but I get TypeError: Transform function invalid for data types.

Upvotes: 3

Views: 2298

Answers (3)

Jeremy Feng

Reputation: 365

You can aggregate data from many columns into a tuple. You can then deal with a single column which contains data from many columns.

My solution:

df["pred_gt"] = list(zip(df["pred"], df["gt"]))
df["counts"] = df.groupby("filename")["pred_gt"].transform(
    lambda x: x.apply(lambda y: y[0] != y[1]).sum()
)
print(df)
    pred  gt filename pred_gt  counts
0      0   0   f1.wav  (0, 0)       1
1      0   0   f1.wav  (0, 0)       1
2      2   0   f1.wav  (2, 0)       1
3      0   0   f1.wav  (0, 0)       1
4      0   0   f1.wav  (0, 0)       1
5      0   0   f1.wav  (0, 0)       1
6      0   0   f1.wav  (0, 0)       1
7      0   0   f1.wav  (0, 0)       1
8      0   0   f1.wav  (0, 0)       1
9      0   0   f1.wav  (0, 0)       1
10     0   0   f2.wav  (0, 0)       4
11     0   0   f2.wav  (0, 0)       4
12     2   2   f2.wav  (2, 2)       4
13     0   2   f2.wav  (0, 2)       4
14     0   2   f2.wav  (0, 2)       4
15     0   0   f2.wav  (0, 0)       4
16     0   0   f2.wav  (0, 0)       4
17     2   2   f2.wav  (2, 2)       4
18     0   2   f2.wav  (0, 2)       4
19     2   0   f2.wav  (2, 0)       4

This method also works for 3 or more columns.
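
As a rough sketch of the three-column case, assume a hypothetical third column named gt2 (it is not in the question's data) and treat a row as a mismatch when its three values are not all equal. Both the column and that mismatch rule are assumptions made for illustration:

```python
import pandas as pd

# Small made-up frame; "gt2" is a hypothetical third column.
df = pd.DataFrame({
    "pred": [0, 2, 0, 0],
    "gt": [0, 0, 0, 2],
    "gt2": [0, 0, 1, 2],
    "filename": ["f1.wav", "f1.wav", "f2.wav", "f2.wav"],
})

# Pack the three columns into one column of tuples.
df["packed"] = list(zip(df["pred"], df["gt"], df["gt2"]))

# A tuple counts as a mismatch when it holds more than one distinct value.
df["counts"] = df.groupby("filename")["packed"].transform(
    lambda x: x.apply(lambda t: len(set(t)) > 1).sum()
)

print(df)
```

Here f1.wav has one mismatching row and f2.wav has two, so counts comes out as 1, 1, 2, 2.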

Upvotes: 0

Gonçalo Peres

Reputation: 13582

Considering that df is the dataframe OP shares in the question, to group by filename and keep only the rows where pred != gt, one can use pandas.DataFrame.groupby and pandas.core.groupby.GroupBy.apply as follows

df2 = df.groupby('filename').apply(lambda x: x[x['pred'] != x['gt']])

[Out]:
             pred  gt filename
filename                      
f1.wav   2      2   0   f1.wav
f2.wav   13     0   2   f2.wav
         14     0   2   f2.wav
         18     0   2   f2.wav
         19     2   0   f2.wav

Assuming one wants to count the occurrences for each filename in a column named count: after the previous operation, filename is both an index level and a column label, so grouping by that name is ambiguous. One therefore has to group by index level (level is one of the parameters groupby accepts) and then use pandas.core.groupby.GroupBy.cumcount. (Note: As opposed to the accepted answer, this approach numbers the rows sequentially within each group instead of repeating the group total.)

df2['count'] = df2.groupby(level=0).cumcount() + 1 # The +1 is to make the count start at 1 instead of 0.

[Out]:
             pred  gt filename  count
filename                             
f1.wav   2      2   0   f1.wav      1
f2.wav   13     0   2   f2.wav      1
         14     0   2   f2.wav      2
         18     0   2   f2.wav      3
         19     2   0   f2.wav      4
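
To make the difference between a sequential count and a group total concrete, here is a minimal sketch on a small made-up frame of rows that have already passed the pred != gt filter. The frame and the variable names running and total are illustrative assumptions, not part of the question:

```python
import pandas as pd

# Hypothetical rows that already satisfy pred != gt.
mismatches = pd.DataFrame({
    "filename": ["f1.wav", "f2.wav", "f2.wav", "f2.wav"],
    "pred": [2, 0, 0, 2],
    "gt": [0, 2, 2, 0],
})

grouped = mismatches.groupby("filename")

# cumcount numbers the rows inside each group, starting at 0.
running = grouped.cumcount() + 1

# transform("size") repeats the group's row count on every row.
total = grouped["pred"].transform("size")

print(mismatches.assign(running=running, total=total))
```

In this sketch running is 1, 1, 2, 3 while total is 1, 3, 3, 3; the second form is the shape the question asks for, restricted to the mismatching rows.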

A one-liner would look like the following

df2 = df.groupby('filename').apply(lambda x: x[x['pred'] != x['gt']]).assign(count=lambda d: d.groupby(level=0).cumcount() + 1)

[Out]:
             pred  gt filename  count
filename                             
f1.wav   2      2   0   f1.wav      1
f2.wav   13     0   2   f2.wav      1
         14     0   2   f2.wav      2
         18     0   2   f2.wav      3
         19     2   0   f2.wav      4

If a separate count column is not required, one can take df2 as it was after the first operation in this answer (when df2 was created) and simply use the following, which gives a higher-level, per-file overview

df3 = df2.groupby(level=0).count().iloc[:, 0]

[Out]:
filename
f1.wav    1
f2.wav    4
Name: pred, dtype: int64
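
The same per-file totals can probably be reached without groupby.apply by filtering first and then counting rows per group. This is a sketch of that alternative on made-up data, not the method used above; note that a file with no mismatches (f3.wav here, a hypothetical addition) drops out of the result entirely:

```python
import pandas as pd

# Small made-up frame; f3.wav has no mismatching rows.
df = pd.DataFrame({
    "pred": [0, 2, 0, 0, 2, 0],
    "gt": [0, 0, 0, 2, 0, 0],
    "filename": ["f1.wav", "f1.wav", "f2.wav", "f2.wav", "f2.wav", "f3.wav"],
})

# Keep only the rows where the prediction disagrees with the ground truth.
mismatches = df[df["pred"] != df["gt"]]

# Count the remaining rows per file.
per_file = mismatches.groupby("filename").size()

print(per_file)
```

If files with zero mismatches must appear with a count of 0, summing the boolean mask per group keeps them, because no rows are dropped before grouping.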

Upvotes: 1

Cameron Riddell

Reputation: 13407

.transform operates on each column individually, so you won't be able to access both 'pred' and 'gt' in a transform operation.

This leaves you with 2 options:

  1. aggregate and reindex or join back to the original shape
  2. pre-compute the boolean array and .transform on that
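
For comparison, option 1 might look like the following sketch on a small made-up frame: aggregate to one value per filename, then map that value back onto the original rows. The name per_file and the use of Series.map for the join-back step are choices made here for illustration:

```python
import pandas as pd

# Small made-up frame with the same columns as in the question.
df = pd.DataFrame({
    "pred": [0, 2, 0, 0, 2],
    "gt": [0, 0, 0, 2, 0],
    "filename": ["f1.wav", "f1.wav", "f2.wav", "f2.wav", "f2.wav"],
})

# Step 1: aggregate to one mismatch count per filename.
per_file = (df["pred"] != df["gt"]).groupby(df["filename"]).sum()

# Step 2: look up each row's filename to get back to the original shape.
df["counts"] = df["filename"].map(per_file)

print(df)
```

This gives counts of 1, 1, 2, 2, 2. It needs two passes where a transform needs one, which is one reason the transform route tends to be simpler.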

Approach 2 will probably be the fastest here:

df['counts'] = (
    (df['pred'] != df['gt'])
    .groupby(df['filename']).transform('sum')
)

print(df)
    pred  gt filename  counts
0      0   0   f1.wav       1
1      0   0   f1.wav       1
2      2   0   f1.wav       1
3      0   0   f1.wav       1
4      0   0   f1.wav       1
5      0   0   f1.wav       1
6      0   0   f1.wav       1
7      0   0   f1.wav       1
8      0   0   f1.wav       1
9      0   0   f1.wav       1
10     0   0   f2.wav       4
11     0   0   f2.wav       4
12     2   2   f2.wav       4
13     0   2   f2.wav       4
14     0   2   f2.wav       4
15     0   0   f2.wav       4
16     0   0   f2.wav       4
17     2   2   f2.wav       4
18     0   2   f2.wav       4
19     2   0   f2.wav       4

Note that f2.wav has 4 instances where 'pred' != 'gt' (index 13, 14, 18, 19), not 3 as in the question's expected output.

Upvotes: 7
