Reputation: 433
I might be doing something wrong, but I'm trying to calculate a rolling average (this example uses a sum instead, for simplicity) after grouping a DataFrame. Up to that point everything works, but when I apply a shift, the values spill over into the next group. See the example below:
import pandas as pd
df = pd.DataFrame({'X': ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'],
                   'Y': [1, 2, 3, 1, 2, 3, 1, 2, 3]})
grouped_df = df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum().shift(periods=1)
print(grouped_df)
Expected result:
X
A  0    NaN
   1    NaN
   2    3.0
B  3    NaN
   4    NaN
   5    3.0
C  6    NaN
   7    NaN
   8    3.0
Result I actually get:
X
A  0    NaN
   1    NaN
   2    3.0
B  3    5.0
   4    NaN
   5    3.0
C  6    5.0
   7    NaN
   8    3.0
You can see that the result at A2 gets passed to B3, and the result at B5 to C6. Is this the intended behaviour and I'm doing something wrong, or is it a bug in pandas?
Thanks
Upvotes: 2
Views: 908
Reputation: 150785
The problem is that
df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum()
returns a new Series, so when you chain shift() onto it, you shift the Series as a whole, not within each group.
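To make that concrete, here is a minimal sketch using the question's data. The names rolled and shifted are only illustrative; the point is that the intermediate result is an ordinary Series, so shift() moves values by position and knows nothing about the groups.

```python
import pandas as pd

df = pd.DataFrame({'X': ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'],
                   'Y': [1, 2, 3, 1, 2, 3, 1, 2, 3]})

# The grouped rolling sum comes back as a plain Series indexed by
# (X, original row label); the grouping itself is not carried along.
rolled = df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum()
print(type(rolled).__name__)   # Series

# Series.shift moves every value down one position, so the last value
# of group A (5.0) lands on the first row of group B.
shifted = rolled.shift(periods=1)
print(shifted.loc[('B', 3)])   # 5.0
```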
You need another groupby to shift within each group:
grouped_df = (df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum()
                .groupby(level=0).shift(periods=1)
             )
Or use groupby.transform:
grouped_df = (df.groupby('X')['Y']
                .transform(lambda x: x.rolling(window=2, min_periods=2)
                                      .sum().shift(periods=1))
             )
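Output of the first approach (the transform version returns the same values on the original 0-8 index, without the X level):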
X
A  0    NaN
   1    NaN
   2    3.0
B  3    NaN
   4    NaN
   5    3.0
C  6    NaN
   7    NaN
   8    3.0
Name: Y, dtype: float64
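As a quick sanity check, the two approaches can be compared side by side. This is only a sketch on the question's data; via_groupby, via_transform and flat are illustrative names, and dropping the X level is just one way to line up the two indexes.

```python
import pandas as pd

df = pd.DataFrame({'X': ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'],
                   'Y': [1, 2, 3, 1, 2, 3, 1, 2, 3]})

# Approach 1: group again on the X level before shifting.
via_groupby = (df.groupby(by='X')['Y'].rolling(window=2, min_periods=2).sum()
                 .groupby(level=0).shift(periods=1))

# Approach 2: do the rolling sum and the shift inside each group.
via_transform = (df.groupby('X')['Y']
                   .transform(lambda x: x.rolling(window=2, min_periods=2)
                                         .sum().shift(periods=1)))

# Approach 1 keeps X as an outer index level; approach 2 keeps the
# original flat index. Drop the X level to compare them row by row.
flat = via_groupby.droplevel(0)
print(pd.DataFrame({'via_groupby': flat, 'via_transform': via_transform}))
```

The transform version is handy when the result should be assigned straight back as a new column of df, since its index already matches.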
Upvotes: 2