Russell Burdt

User Defined Aggregate Function in PySpark SQL

How to implement a User Defined Aggregate Function (UDAF) in PySpark SQL?

pyspark version = 3.0.2
python version = 3.7.10

As a minimal example, I'd like to replace the AVG aggregate function with a UDAF:

import pandas as pd
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext()
sql = SQLContext(sc)
df = sql.createDataFrame(
    pd.DataFrame({'id': [1, 1, 2, 2], 'value': [1, 2, 3, 4]}))
df.createTempView('df')
rv = sql.sql('SELECT id, AVG(value) FROM df GROUP BY id').toPandas()

where rv will be:

In [2]: rv
Out[2]:
   id  avg(value)
0   1         1.5
1   2         3.5

How can a UDAF replace AVG in the query?

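For example, this does not work:

import numpy as np
def udf_avg(x):
    return np.mean(x)
# A UDF registered this way is a row-at-a-time (scalar) function, so Spark
# does not treat udf_avg(value) as an aggregate and rejects the query because
# `value` is neither in the GROUP BY nor inside an aggregate function.
sql.udf.register('udf_avg', udf_avg)
rv = sql.sql('SELECT id, udf_avg(value) FROM df GROUP BY id').toPandas()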

The idea is to implement a UDAF in pure Python for processing not supported by SQL aggregate functions (e.g. a low-pass filter).


Answers (2)

Russell Burdt

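A Pandas UDF can be used. With Spark 3.0 and Python 3.6+, a Pandas UDF whose type hints take a pd.Series and return a scalar is treated as a grouped aggregate function, so it can be called in a GROUP BY query. See the issue and documentation for details.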

Full implementation in Spark SQL:

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    pd.DataFrame({'id': [1, 1, 2, 2], 'value': [1, 2, 3, 4]}))
df.createTempView('df')

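# The pd.Series -> float type hints make this a grouped aggregate Pandas UDF:
# Spark calls it once per group with that group's values as a pd.Series.
@pandas_udf(DoubleType())
def avg_udf(s: pd.Series) -> float:
    return s.mean()

# Register the UDF so it can be called by name from Spark SQL.
spark.udf.register('avg_udf', avg_udf)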

rv = spark.sql('SELECT id, avg_udf(value) FROM df GROUP BY id').toPandas()

with return value:

In [2]: rv
Out[2]:
   id  avg_udf(value)
0   1             1.5
1   2             3.5
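The same pattern can cover the low-pass filter mentioned in the question. A grouped aggregate Pandas UDF may take several columns, so a group can be sorted by a time column before filtering (Spark does not guarantee row order within a group). Below is a minimal sketch of such a function body in plain pandas. The name `lowpass_last`, the ordering column `t`, the choice of exponential smoothing as the filter, and the smoothing factor 0.5 are all illustrative assumptions, and the sketch is exercised with a pandas groupby rather than with Spark:

```python
import pandas as pd


def lowpass_last(t: pd.Series, v: pd.Series) -> float:
    """Hypothetical low-pass aggregate: last value of an exponential moving average.

    t is an ordering column (e.g. a timestamp), v holds the signal values;
    both contain the rows of a single group.
    """
    alpha = 0.5  # assumed smoothing factor, chosen only for illustration
    # Spark makes no promise about row order inside a group, so sort by t first.
    ordered = pd.DataFrame({'t': t.to_numpy(), 'v': v.to_numpy()}).sort_values('t')
    # First-order IIR filter: y[n] = alpha * x[n] + (1 - alpha) * y[n - 1].
    smoothed = ordered['v'].ewm(alpha=alpha, adjust=False).mean()
    return float(smoothed.iloc[-1])


# Local check with pandas: each group's columns are passed in as pd.Series,
# which is also how Spark calls a Series-to-scalar Pandas UDF.
data = pd.DataFrame({'id': [1, 1, 1, 2, 2],
                     't': [2, 1, 3, 1, 2],
                     'value': [10.0, 0.0, 20.0, 4.0, 8.0]})
result = {int(key): lowpass_last(g['t'], g['value']) for key, g in data.groupby('id')}
print(result)  # {1: 12.5, 2: 6.0}
```

To call it from SQL, the function would presumably be decorated with `@pandas_udf(DoubleType())` and registered with `spark.udf.register` in the same way as `avg_udf` above, then used as `lowpass_last(t, value)` in a GROUP BY query on a table that actually has such a `t` column; that Spark wiring is not exercised by this sketch.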


mck

You can use a Pandas UDF of the GROUPED_AGG type. For each group it receives the input column from Spark as a pandas Series, so you can call Series.mean on it and return a single value.

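import pyspark.sql.functions as F

# GROUPED_AGG: Spark calls avg_udf once per group with a pd.Series and
# expects a single scalar back. On Spark 3.0 with Python 3.6+ this style still
# works, but type hints are the preferred way to declare the UDF type.
@F.pandas_udf('float', F.PandasUDFType.GROUPED_AGG)
def avg_udf(s):
    return s.mean()

df2 = df.groupBy('id').agg(avg_udf('value'))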

df2.show()
+---+--------------+
| id|avg_udf(value)|
+---+--------------+
|  1|           1.5|
|  2|           3.5|
+---+--------------+
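Because a GROUPED_AGG UDF only ever sees one pandas Series per group and returns one scalar, its body can be checked with plain pandas before it is wrapped for Spark. A minimal sketch, assuming only pandas is installed; `mean_agg` is a hypothetical undecorated copy of the `avg_udf` body, and pandas' own groupby stands in for Spark's grouping:

```python
import pandas as pd


def mean_agg(s: pd.Series) -> float:
    # Undecorated copy of avg_udf's body: one group's values in, one scalar out.
    return float(s.mean())


pdf = pd.DataFrame({'id': [1, 1, 2, 2], 'value': [1, 2, 3, 4]})
# SeriesGroupBy.agg hands each group's 'value' column to mean_agg as a pd.Series,
# similar to how Spark feeds a GROUPED_AGG Pandas UDF.
expected = pdf.groupby('id')['value'].agg(mean_agg)
print(expected.to_dict())  # {1: 1.5, 2: 3.5}
```

The pandas result matches the `df2.show()` output above; any remaining differences on a real cluster would come from Spark's type handling, for example the `'float'` return type mapping to Spark's 32-bit FloatType.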

It can also be registered for use in SQL:

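# createOrReplaceTempView avoids an error if the view 'df' already exists
# (e.g. after running the question's code); spark is the active SparkSession.
df.createOrReplaceTempView('df')
spark.udf.register('avg_udf', avg_udf)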

df2 = spark.sql("select id, avg_udf(value) from df group by id")
df2.show()
+---+--------------+
| id|avg_udf(value)|
+---+--------------+
|  1|           1.5|
|  2|           3.5|
+---+--------------+
