mik1904

Reputation: 1365

PySpark aggregate operation that sums all rows in a DataFrame column of type MapType(*, IntegerType())

Suppose you create a Spark DataFrame with an explicit schema:

import pyspark.sql.functions as sf
from pyspark.sql.types import *

dfschema = StructType([
    StructField("_1", ArrayType(IntegerType())),
    StructField("_2", ArrayType(IntegerType())),
])
df = spark.createDataFrame([[[1, 2, 5], [13, 74, 1]],
                            [[1, 2, 3], [77, 23, 15]]],
                           schema=dfschema)
df = df.select(sf.map_from_arrays("_1", "_2").alias("omap"))
df = df.withColumn("id", sf.lit(1))

The above DataFrame looks like this:

+---------------------------+---+
|omap                       |id |
+---------------------------+---+
|{1 -> 13, 2 -> 74, 5 -> 1} |1  |
|{1 -> 77, 2 -> 23, 3 -> 15}|1  |
+---------------------------+---+

I would like to perform the following operation:

df.groupby("id").agg(sum_counter("omap")).show(truncate=False)

Could you help me define a sum_counter function, using only SQL functions from pyspark.sql.functions (so no UDFs), that produces the following DataFrame:

+---+-----------------------------------+
|id |mapsum                             |
+---+-----------------------------------+
|1  |{1 -> 90, 2 -> 97, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+

I could solve this using applyInPandas:

from pyspark.sql.types import *
from collections import Counter
import pandas as pd

reschema = StructType([
    StructField("id", LongType()),
    StructField("mapsum", MapType(IntegerType(), IntegerType()))
])

def sum_counter(key: tuple, pdf: pd.DataFrame) -> pd.DataFrame:
    # the grouping key arrives as a tuple; append the merged Counter to it
    return pd.DataFrame([
        key + (sum([Counter(x) for x in pdf["omap"]], Counter()),)
    ])

df.groupby("id").applyInPandas(sum_counter, reschema).show(truncate=False)

+---+-----------------------------------+
|id |mapsum                             |
+---+-----------------------------------+
|1  |{1 -> 90, 2 -> 97, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+

However, for performance reasons, I would like to avoid using applyInPandas or UDFs. Any ideas?

Upvotes: 1

Views: 522

Answers (2)

mik1904

Reputation: 1365

In the end I solved it like this:

import pyspark.sql.functions as sf

def sum_counter(mapcoln: str):
    # distinct keys across all maps collected in the group
    dkeys = sf.array_distinct(sf.flatten(sf.collect_list(sf.map_keys(mapcoln))))
    # for each distinct key, sum its value over all collected maps,
    # skipping maps that do not contain the key
    dkeyscount = sf.transform(
        dkeys,
        lambda ukey: sf.aggregate(
            sf.collect_list(mapcoln),
            sf.lit(0),
            lambda acc, mapentry: sf.when(
                ~sf.isnull(sf.element_at(mapentry, ukey)),
                acc + sf.element_at(mapentry, ukey),
            ).otherwise(acc),
        ),
    )
    return sf.map_from_arrays(dkeys, dkeyscount).alias("mapsum")

df.groupby("id").agg(sum_counter("omap")).show(truncate=False)


+---+-----------------------------------+
|id |mapsum                             |
+---+-----------------------------------+
|1  |{1 -> 90, 2 -> 97, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+
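As a side note, on Spark 3.1+ (where aggregate and map_zip_with are both exposed in pyspark.sql.functions) the same key-wise merge can also be written by folding the collected maps directly. This is only a sketch of that idea (sum_counter_zip is just an illustrative name), not benchmarked against the version above:

import pyspark.sql.functions as sf

def sum_counter_zip(mapcoln: str):
    # Fold every map collected in the group into a single map, adding values
    # key-wise. map_zip_with walks the union of keys, so a key missing from one
    # side shows up as null and is treated as 0 via coalesce.
    return sf.aggregate(
        sf.collect_list(mapcoln),
        sf.create_map().cast("map<int,int>"),  # empty map as the accumulator
        lambda acc, m: sf.map_zip_with(
            acc,
            m,
            lambda k, v1, v2: sf.coalesce(v1, sf.lit(0)) + sf.coalesce(v2, sf.lit(0)),
        ),
    ).alias("mapsum")

df.groupby("id").agg(sum_counter_zip("omap")).show(truncate=False)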

Upvotes: 0

Bartosz Gajda

Reputation: 1167

You can first explode the omap column into individual rows, with the key and value in separate columns, and then aggregate them like so:

exploded_df = df.select("*", sf.explode("omap"))
agg_df = exploded_df.groupBy("id", "key").sum("value")
agg_df.groupBy("id").agg(
    sf.map_from_entries(sf.collect_list(sf.struct("key", "sum(value)"))).alias("mapsum")
).show(truncate=False)

+---+-----------------------------------+
|id |mapsum                             |
+---+-----------------------------------+
|1  |{2 -> 97, 1 -> 90, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+
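A minor variant of the same approach, if you prefer an explicit alias over the auto-generated sum(value) column name (sketch only, same logic):

import pyspark.sql.functions as sf

exploded_df = df.select("*", sf.explode("omap"))          # one row per (key, value) pair
summed_df = exploded_df.groupBy("id", "key").agg(
    sf.sum("value").alias("value")                        # avoids the sum(value) column name
)
summed_df.groupBy("id").agg(
    sf.map_from_entries(sf.collect_list(sf.struct("key", "value"))).alias("mapsum")
).show(truncate=False)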

Upvotes: 2
