Mudyla

Reputation: 277

Using groupby, shift and rolling in Pandas

I am trying to calculate rolling averages within groups. I want each row's average to use only the rows above it, so I thought the easiest way would be to call shift() and then rolling(). The problem is that rolling() then runs over the whole shifted column rather than within each group, so the windows at the start of groups 2 and 3 pick up values from the previous group and those rows come out wrong. Column 'ma' should be NaN in the 4th and 7th rows (index 3 and 6). How can I achieve this?

import pandas as pd

df = pd.DataFrame(
    {"Group": [1, 2, 3, 1, 2, 3, 1, 2, 3],
     "Value": [2.5, 2.9, 1.6, 9.1, 5.7, 8.2, 4.9, 3.1, 7.5]
     })

df = df.sort_values(['Group'])
df.reset_index(inplace=True)

df['ma'] = df.groupby('Group', as_index=False)['Value'].shift(1).rolling(3, min_periods=1).mean()

print(df)

I get this:

   index  Group  Value    ma
0      0      1    2.5   NaN
1      3      1    9.1  2.50
2      6      1    4.9  5.80
3      1      2    2.9  5.80
4      4      2    5.7  6.00
5      7      2    3.1  4.30
6      2      3    1.6  4.30
7      5      3    8.2  3.65
8      8      3    7.5  4.90
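A quick way to see where the cross-group values come from is to run the two steps separately. The sketch below rebuilds the question's frame (with `kind="stable"` added to `sort_values` as an assumption, so rows keep their original order within each group) and prints the intermediate result: the grouped `shift(1)` already restarts at each group, and it is the following `rolling()` call, made on the plain shifted Series, whose 3-row windows run across group boundaries.

```python
import pandas as pd

df = pd.DataFrame(
    {"Group": [1, 2, 3, 1, 2, 3, 1, 2, 3],
     "Value": [2.5, 2.9, 1.6, 9.1, 5.7, 8.2, 4.9, 3.1, 7.5]})
# kind="stable" keeps the original row order within each group
df = df.sort_values("Group", kind="stable").reset_index()

# Step 1 on its own: the grouped shift already puts NaN at each group's first row
shifted = df.groupby("Group")["Value"].shift(1)
print(shifted.tolist())
# [nan, 2.5, 9.1, nan, 2.9, 5.7, nan, 1.6, 8.2]

# Step 2 on its own: rolling() is called on a plain Series, so its
# 3-row windows ignore the group boundaries
ma = shifted.rolling(3, min_periods=1).mean()
print(ma.tolist())
```

The second list matches the 'ma' column in the question's output, which points at the ungrouped rolling() rather than at shift().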

I tried answers from a couple of similar questions, but nothing seems to work.

Upvotes: 3

Views: 4266

Answers (1)

Yati Raj

Reputation: 448

If I understand the question correctly, you can get the result in two steps: shift within each group, then compute the rolling mean within each group:

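# Step 1: shift Value down one row within each group (NaN at each group's first row)
df['sa'] = df.groupby('Group', as_index=False)['Value'].transform(lambda x: x.shift(1))

# Step 2: rolling mean of up to 3 shifted values, restarted for each group
df['ma'] = df.groupby('Group', as_index=False)['sa'].transform(lambda x: x.rolling(3, min_periods=1).mean())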

This gives the following output, where 'ma' is the desired column:

   index  Group  Value   sa   ma
0      0      1    2.5  NaN  NaN
1      3      1    9.1  2.5  2.5
2      6      1    4.9  9.1  5.8
3      1      2    2.9  NaN  NaN
4      4      2    5.7  2.9  2.9
5      7      2    3.1  5.7  4.3
6      2      3    1.6  NaN  NaN
7      5      3    8.2  1.6  1.6
8      8      3    7.5  8.2  4.9
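As a sketch, the first step does not need `transform` with a lambda: pandas groupby objects have a built-in `shift`, which restarts at every group. The version below assumes the question's frame (sorted with `kind="stable"`, an added assumption) and should reproduce the table above.

```python
import pandas as pd

df = pd.DataFrame(
    {"Group": [1, 2, 3, 1, 2, 3, 1, 2, 3],
     "Value": [2.5, 2.9, 1.6, 9.1, 5.7, 8.2, 4.9, 3.1, 7.5]})
df = df.sort_values("Group", kind="stable").reset_index()

# Step 1: built-in GroupBy.shift, NaN at the first row of every group
df["sa"] = df.groupby("Group")["Value"].shift(1)

# Step 2: rolling mean of the shifted values, computed separately per group;
# transform returns a Series aligned on df's index
df["ma"] = df.groupby("Group")["sa"].transform(
    lambda s: s.rolling(3, min_periods=1).mean())

print(df)
```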

Edit: the same calculation with a single groupby:

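# Shift within the group, then average up to 3 of the previous values
def shift_ma(x):
    return x.shift(1).rolling(3, min_periods=1).mean()

# reset_index(drop=True) discards whatever index apply() returns and assigns by position,
# so this relies on df already being sorted by Group with a fresh 0..n-1 index, as in the question
df['ma'] = df.groupby('Group', as_index=False)['Value'].apply(shift_ma).reset_index(drop=True)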
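Another option, not from the original answer, is to pass the same `shift_ma` function to `transform` instead of `apply`. `transform` aligns each group's result back on the frame's own index, so this sketch does not depend on the frame being sorted or on `reset_index`; it runs on the unsorted frame from the question.

```python
import pandas as pd

df = pd.DataFrame(
    {"Group": [1, 2, 3, 1, 2, 3, 1, 2, 3],
     "Value": [2.5, 2.9, 1.6, 9.1, 5.7, 8.2, 4.9, 3.1, 7.5]})
# No sorting or reset_index needed for this version

def shift_ma(x):
    # Only rows above the current one: shift by one, then average up to 3 of them
    return x.shift(1).rolling(3, min_periods=1).mean()

# transform calls shift_ma on each group's Series and puts the results
# back under df's original row labels
df["ma"] = df.groupby("Group")["Value"].transform(shift_ma)

print(df.sort_values("Group", kind="stable"))
```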

Upvotes: 3
