Reputation: 2673
How to implement a User Defined Aggregate Function (UDAF) in PySpark SQL?
pyspark version = 3.0.2
python version = 3.7.10
As a minimal example, I'd like to replace the AVG aggregate function with a UDAF:
import pandas as pd
from pyspark import SparkContext
from pyspark.sql import SQLContext
sc = SparkContext()
sql = SQLContext(sc)
df = sql.createDataFrame(
    pd.DataFrame({'id': [1, 1, 2, 2], 'value': [1, 2, 3, 4]}))
df.createTempView('df')
rv = sql.sql('SELECT id, AVG(value) FROM df GROUP BY id').toPandas()
where rv will be:
In [2]: rv
Out[2]:
   id  avg(value)
0   1         1.5
1   2         3.5
How can a UDAF replace AVG in the query?
For example, this does not work:
import numpy as np
def udf_avg(x):
    return np.mean(x)
sql.udf.register('udf_avg', udf_avg)
rv = sql.sql('SELECT id, udf_avg(value) FROM df GROUP BY id').toPandas()
The idea is to implement a UDAF in pure Python for processing that the built-in SQL aggregate functions do not support (e.g. a low-pass filter).
Upvotes: 5
Views: 3343
Reputation: 2673
A Pandas UDF can be used; this type-hint-based definition is supported from Spark 3.0 with Python 3.6+. See the issue and documentation for details.
Full implementation in Spark SQL:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    pd.DataFrame({'id': [1, 1, 2, 2], 'value': [1, 2, 3, 4]}))
df.createTempView('df')
@pandas_udf(DoubleType())
def avg_udf(s: pd.Series) -> float:
    return s.mean()
spark.udf.register('avg_udf', avg_udf)
rv = spark.sql('SELECT id, avg_udf(value) FROM df GROUP BY id').toPandas()
with the return value:
In [2]: rv
Out[2]:
   id  avg_udf(value)
0   1             1.5
1   2             3.5
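The same pattern can cover the low-pass-filter case from the question. The following is a minimal sketch, not a tested Spark recipe. It assumes a hypothetical ordering column `t` (the example DataFrame above has none) and a hypothetical smoothing factor `alpha`, and it uses a first-order exponential smoother via pandas `ewm` as one possible choice of low-pass filter. Spark does not guarantee row order within a group, so the function sorts by `t` itself. The pyspark import sits inside `register_lowpass` so the filter logic can be exercised with pandas alone.

```python
import numpy as np
import pandas as pd


def lowpass_last(t: pd.Series, v: pd.Series, alpha: float = 0.5) -> float:
    """Return the last output of a first-order low-pass filter over one group.

    `t` is a hypothetical ordering column; `alpha` (0 < alpha <= 1) is a
    hypothetical smoothing factor chosen for illustration.
    """
    # Row order inside a Spark group is not guaranteed, so sort by t first.
    ordered = v.iloc[np.argsort(t.to_numpy(), kind="stable")]
    # adjust=False gives the recursion y[k] = alpha * x[k] + (1 - alpha) * y[k-1]
    smoothed = ordered.ewm(alpha=alpha, adjust=False).mean()
    return float(smoothed.iloc[-1])


def register_lowpass(spark, name: str = "lowpass_udf") -> None:
    """Register lowpass_last as a grouped-aggregate Pandas UDF (Spark 3.0+)."""
    from pyspark.sql.functions import pandas_udf

    # Series, Series -> float type hints make this a grouped-aggregate Pandas UDF.
    @pandas_udf("double")
    def lowpass_udf(t: pd.Series, v: pd.Series) -> float:
        return lowpass_last(t, v)

    spark.udf.register(name, lowpass_udf)
```

With a view `df` that has columns `id`, `t` and `value`, the call would be `register_lowpass(spark)` followed by `spark.sql('SELECT id, lowpass_udf(t, value) FROM df GROUP BY id')`. Spark loads all rows of a group into memory before calling this kind of UDF, so very large groups may be a concern.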
Upvotes: 4
Reputation: 42342
You can use a Pandas UDF of type GROUPED_AGG. It receives columns from Spark as pandas Series, so you can call Series.mean on the column.
import pyspark.sql.functions as F
@F.pandas_udf('float', F.PandasUDFType.GROUPED_AGG)
def avg_udf(s):
    return s.mean()
df2 = df.groupBy('id').agg(avg_udf('value'))
df2.show()
+---+--------------+
| id|avg_udf(value)|
+---+--------------+
| 1| 1.5|
| 2| 3.5|
+---+--------------+
It is also possible to register it for use in SQL:
df.createTempView('df')
spark.udf.register('avg_udf', avg_udf)
df2 = spark.sql("select id, avg_udf(value) from df group by id")
df2.show()
+---+--------------+
| id|avg_udf(value)|
+---+--------------+
| 1| 1.5|
| 2| 3.5|
+---+--------------+
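Because the aggregate function only ever sees a pandas Series, its logic can be checked without a Spark session. This is a minimal sketch, assuming the body is factored into a plain helper `avg_impl` (a hypothetical name). Here pandas' own `groupby(...).agg` stands in for Spark's GROUP BY, so it checks the per-group arithmetic, not Spark's execution.

```python
import pandas as pd


def avg_impl(s: pd.Series) -> float:
    # Same body as avg_udf: called once per group with that group's column.
    return float(s.mean())


# Local stand-in for GROUP BY id: pandas also calls avg_impl once per group,
# passing that group's 'value' entries as a Series.
pdf = pd.DataFrame({'id': [1, 1, 2, 2], 'value': [1, 2, 3, 4]})
local = pdf.groupby('id')['value'].agg(avg_impl)
print(local)
```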
Upvotes: 2