Miguel

Reputation: 433

Pandas - shifting a rolling sum after grouping spills over to following groups

I might be doing something wrong, but I am trying to calculate a rolling average after grouping a dataframe (the example below uses a sum instead, for simplicity). The grouped rolling calculation works as expected, but when I then apply a shift, values spill over into the following group. See the example below:

import pandas as pd

df = pd.DataFrame({'X': ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'],
                   'Y': [1, 2, 3, 1, 2, 3, 1, 2, 3]})

grouped_df = df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum().shift(periods=1)
print(grouped_df)

Expected result:

X   
A  0    NaN
   1    NaN
   2    3.0
B  3    NaN
   4    NaN
   5    3.0
C  6    NaN
   7    NaN
   8    3.0

Result I actually get:

X   
A  0    NaN
   1    NaN
   2    3.0
B  3    5.0
   4    NaN
   5    3.0
C  6    5.0
   7    NaN
   8    3.0

You can see that the result of A2 gets passed to B3, and the result of B5 to C6. Is this the intended behaviour and I'm doing something wrong, or is it a bug in pandas?

Thanks

Upvotes: 2

Views: 908

Answers (1)

Quang Hoang

Reputation: 150785

The problem is that

df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum()

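returns a new Series indexed by (X, original row label). When you chain shift() onto it, you shift that Series as a whole rather than within each group, so the last value of each group moves into the first row of the next one.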
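To make this concrete, here is a small sketch that inspects the intermediate result using the dataframe from the question. The variable names rolled and shifted are only illustrative.

```python
import pandas as pd

df = pd.DataFrame({'X': ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'],
                   'Y': [1, 2, 3, 1, 2, 3, 1, 2, 3]})

# The rolling sum is computed per group, but the result is a single Series
# whose index has two levels: X and the original row label.
rolled = df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum()
print(rolled.index.nlevels)   # 2
print(rolled.tolist())        # [nan, 3.0, 5.0, nan, 3.0, 5.0, nan, 3.0, 5.0]

# A plain shift moves every value down one row and ignores the X level,
# so the last value of group A (5.0) lands on the first row of group B.
shifted = rolled.shift(periods=1)
print(shifted.loc[('B', 3)])  # 5.0
```

By the time shift() runs, the grouping has already been consumed by rolling(); the group label survives only as an index level.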

You need another groupby to shift within the group:

grouped_df = (df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum()
                .groupby(level=0).shift(periods=1)
             )

Or use groupby.transform:

grouped_df = (df.groupby('X')['Y']
                .transform(lambda x: x.rolling(window=2, min_periods=2)
                                      .sum().shift(periods=1))
             )

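Output (shown for the first version; the transform version gives the same values but keeps the original 0-8 row index, without the X level):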

X   
A  0    NaN
   1    NaN
   2    3.0
B  3    NaN
   4    NaN
   5    3.0
C  6    NaN
   7    NaN
   8    3.0
Name: Y, dtype: float64
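As a quick sanity check, the two approaches can be compared side by side. This is a sketch on the question's data; the names by_level, by_transform and same are illustrative, and numpy is pulled in only for the NaN-aware comparison.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'X': ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'],
                   'Y': [1, 2, 3, 1, 2, 3, 1, 2, 3]})

# Version 1: group again on the first index level (X) before shifting.
by_level = (df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum()
              .groupby(level=0).shift(periods=1))

# Version 2: roll and shift inside each group via transform.
by_transform = (df.groupby('X')['Y']
                  .transform(lambda x: x.rolling(window=2, min_periods=2)
                                        .sum().shift(periods=1)))

# Version 1 carries an (X, row label) index; version 2 keeps df's own index.
print(by_level.index.nlevels)        # 2
print(by_transform.index.tolist())   # [0, 1, 2, 3, 4, 5, 6, 7, 8]

# Drop the X level and align on the original row labels before comparing.
aligned = by_level.droplevel(0).reindex(df.index)
same = np.allclose(aligned.to_numpy(), by_transform.to_numpy(), equal_nan=True)
print(same)                          # True
```

Because the transform version keeps the dataframe's own index, it can be assigned straight back as a new column, whereas the first version needs the X level dropped first.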

Upvotes: 2
