Reputation: 7161
I am struggling to figure out how to use group_by and apply a custom function in Polars.
Coming from Pandas, I was using:
import pandas as pd
from scipy.stats import spearmanr
def get_score(df):
    return spearmanr(df["prediction"], df["target"]).correlation
df = pd.DataFrame({
    "era": [1, 1, 1, 2, 2, 2, 5],
    "prediction": [2, 4, 5, 190, 1, 4, 1],
    "target": [1, 3, 2, 1, 43, 3, 1]
})
correlations = df.groupby("era").apply(get_score)
Polars has map_groups() to apply a custom function over groups, which I tried:
correlations = df.group_by("era").map_groups(get_score)
But this fails with the error message:
'Could not get DataFrame attribute '_df'. Make sure that you return a DataFrame object.: PyErr { type: <class 'AttributeError'>, value: AttributeError("'float' object has no attribute '_df'"), traceback: None }
Any ideas?
Upvotes: 22
Views: 21627
Reputation: 14710
Polars has the pl.corr() function, which supports method="spearman".
If you want to use a custom function, you could do it like this:
import polars as pl
from typing import List
from scipy import stats
df = pl.DataFrame({
    "g": [1, 1, 1, 2, 2, 2, 5],
    "a": [2, 4, 5, 190, 1, 4, 1],
    "b": [1, 3, 2, 1, 43, 3, 1]
})
def get_score(args: List[pl.Series]) -> pl.Series:
    # args holds one Series per entry in `exprs`, restricted to the current group
    return pl.Series([stats.spearmanr(args[0], args[1]).correlation], dtype=pl.Float64)
# Custom function applied per group
(df.group_by("g", maintain_order=True)
    .agg(
        pl.map_groups(
            exprs=["a", "b"],
            function=get_score).alias("corr")
    ))
# Built-in expression, no custom function needed
(df.group_by("g", maintain_order=True)
    .agg(
        pl.corr("a", "b", method="spearman").alias("corr")
    ))
Both output:
shape: (3, 2)
┌─────┬──────┐
│ g ┆ corr │
│ --- ┆ --- │
│ i64 ┆ f64 │
╞═════╪══════╡
│ 1 ┆ 0.5 │
│ 2 ┆ -1.0 │
│ 5 ┆ NaN │
└─────┴──────┘
We can also apply custom functions to single expressions via .map_elements.
Below is an example of squaring a column with a custom function and with native Polars expressions. The expression syntax should always be preferred, as it is much faster.
(df.group_by("g")
    .agg(
        pl.col("a").map_elements(lambda group: group**2).alias("squared1"),
        (pl.col("a")**2).alias("squared2")
    ))
Upvotes: 34