Reputation: 85
I have a pandas DataFrame like this:
group cat
1 0
2 0
1 0
1 1
2 0
2 1
1 2
1 2
I'm trying to group the data by group, then apply a custom function to the past 5 rows of each group.
The custom function looks like this:
def unalikeability(data):
    num_observations = data.shape[0]
    counts = data.value_counts()
    return 1 - ((counts / num_observations)**2).sum()
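For reference, the function computes 1 minus the sum of squared category proportions. A quick sketch scoring a single window, assuming the window is a pandas Series of category codes (here group 1's cat values from the example): the proportions of 0, 1 and 2 are 0.4, 0.2 and 0.4, so the score is 1 - (0.16 + 0.04 + 0.16) = 0.64.

```python
import pandas as pd

def unalikeability(data):
    num_observations = data.shape[0]
    counts = data.value_counts()
    return 1 - ((counts / num_observations) ** 2).sum()

# Group 1's cat values from the example frame, in row order
window = pd.Series([0, 0, 1, 2, 2])
score = unalikeability(window)
print(round(score, 6))  # 0.64
```

A window containing a single category scores 0, and the score grows as the categories become more evenly mixed.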
Desired output:
group unalikeability
1 result calculated by the function
1
1
1
2
2
2
2
I can get the past 5 rows using groupby().rolling(), but the Rolling object in pandas doesn't have the shape attribute or the value_counts method that a DataFrame has. I tried creating a DataFrame from the Rolling object, but that isn't allowed either.
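The missing methods only matter if you work with the Rolling object directly. When you pass a function to rolling(...).apply(), pandas calls it once per window, and with the default raw=False each window arrives as a Series (so shape and value_counts exist), while raw=True passes a plain NumPy array. A small sketch that records the type it receives; record_type and seen_types are hypothetical names used only for illustration:

```python
import pandas as pd

s = pd.Series([0, 0, 1, 2, 2])
seen_types = []  # hypothetical list collecting the window types

def record_type(window):
    # Remember what kind of object rolling().apply hands to the function
    seen_types.append(type(window).__name__)
    return len(window)

s.rolling(3).apply(record_type)            # raw=False (default): windows are Series
s.rolling(3).apply(record_type, raw=True)  # raw=True: windows are NumPy arrays
print(seen_types)
```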
Upvotes: 0
Views: 124
Reputation: 2152
import pandas as pd
def unalikeability(data):
    num_observations = data.shape[0]
    counts = data.value_counts()
    return 1 - ((counts / num_observations) ** 2).sum()
df = pd.DataFrame({'group': [1, 2, 1, 1, 2, 2, 1, 1], 'cat': [0, 0, 0, 1, 0, 1, 2, 2]})
print(df)
window_size = 5
def g(df):
    # Work on a copy so the new column is added to this group's frame only
    df = df.copy()
    # Keep the group's rows in their original (chronological) order,
    # with a positional index 0..n-1
    cat = df['cat'].reset_index(drop=True)
    # For row i, score row i plus up to window_size - 1 rows before it
    df['unalikeability'] = [unalikeability(cat.iloc[max(0, i - window_size + 1): i + 1])
                            for i in range(len(df))]
    return df
out = df.groupby('group').apply(g).reset_index(drop=True)
print(out)
"""
   group  cat  unalikeability
0      1    0        0.000000
1      1    0        0.000000
2      1    1        0.444444
3      1    2        0.625000
4      1    2        0.640000
5      2    0        0.000000
6      2    0        0.000000
7      2    1        0.444444
"""
Upvotes: 0
Reputation: 262224
You can call apply on the rolling object with your function. Depending on whether you want the output to be computed only on full chunks (5 values) or on chunks of any size, use min_periods:
import pandas as pd

df = pd.DataFrame({'group': [1, 2, 1, 1, 2, 2, 1, 1], 'cat': [0, 0, 0, 1, 0, 1, 2, 2]})

def unalikeability(data):
    num_observations = data.shape[0]
    counts = data.value_counts()
    return 1 - ((counts / num_observations)**2).sum()

# compute the score only if we have 5 rows
# (droplevel('group') restores the original index so the result aligns with df)
df['out1'] = (df.groupby('group')
                .rolling(5)['cat']
                .apply(unalikeability)
                .droplevel('group')
             )

# compute the score with incomplete chunks
df['out2'] = (df.groupby('group')
                .rolling(5, min_periods=1)['cat']
                .apply(unalikeability)
                .droplevel('group')
             )
Output:
   group  cat  out1      out2
0      1    0   NaN  0.000000
1      2    0   NaN  0.000000
2      1    0   NaN  0.000000
3      1    1   NaN  0.444444
4      2    0   NaN  0.000000
5      2    1   NaN  0.444444
6      1    2   NaN  0.625000
7      1    2  0.64  0.640000
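Because unalikeability only needs the category counts, a possible variant passes raw=True so each window arrives as a NumPy array instead of a Series, with np.unique(..., return_counts=True) standing in for value_counts. This is a sketch that should reproduce the out2 column, not a measured speed-up; unalikeability_raw and out2_raw are hypothetical names, and it selects the column before rolling (groupby('group')['cat'].rolling(...)) rather than after:

```python
import numpy as np
import pandas as pd

def unalikeability_raw(values):
    # values is a 1-D NumPy array because apply is called with raw=True
    _, counts = np.unique(values, return_counts=True)
    proportions = counts / values.size
    return 1 - (proportions ** 2).sum()

df = pd.DataFrame({'group': [1, 2, 1, 1, 2, 2, 1, 1],
                   'cat': [0, 0, 0, 1, 0, 1, 2, 2]})

df['out2_raw'] = (df.groupby('group')['cat']
                    .rolling(5, min_periods=1)
                    .apply(unalikeability_raw, raw=True)
                    .droplevel('group'))
print(df)
```

Skipping the per-window Series construction is the main reason to prefer raw=True when the function does not need an index.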
Upvotes: 0