Reputation: 1365
Suppose you create a Spark DataFrame with a precise schema:
import pyspark.sql.functions as sf
from pyspark.sql.types import *
dfschema = StructType([
    StructField("_1", ArrayType(IntegerType())),
    StructField("_2", ArrayType(IntegerType())),
])
df = spark.createDataFrame(
    [[[1, 2, 5], [13, 74, 1]],
     [[1, 2, 3], [77, 23, 15]]],
    schema=dfschema,
)
df = df.select(sf.map_from_arrays("_1", "_2").alias("omap"))
df = df.withColumn("id", sf.lit(1))
The above DataFrame looks like this:
+---------------------------+---+
|omap |id |
+---------------------------+---+
|{1 -> 13, 2 -> 74, 5 -> 1} |1 |
|{1 -> 77, 2 -> 23, 3 -> 15}|1 |
+---------------------------+---+
I would like to perform the following operation:
df.groupby("id").agg(sum_counter("omap")).show(truncate=False)
Could you please help me define a sum_counter function that uses only SQL functions from pyspark.sql.functions (so no UDFs) and produces the following DataFrame:
+---+-----------------------------------+
|id |mapsum |
+---+-----------------------------------+
|1 |{1 -> 90, 2 -> 97, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+
I could solve this using applyInPandas:
from pyspark.sql.types import *
from collections import Counter
import pandas as pd
reschema = StructType([
    StructField("id", LongType()),
    StructField("mapsum", MapType(IntegerType(), IntegerType()))
])
def sum_counter(key: tuple, pdf: pd.DataFrame) -> pd.DataFrame:
    # `key` is a tuple holding the grouping key(s); append the merged Counter to it.
    return pd.DataFrame([
        key + (sum([Counter(x) for x in pdf["omap"]], Counter()),)
    ])
df.groupby("id").applyInPandas(sum_counter, reschema).show(truncate=False)
+---+-----------------------------------+
|id |mapsum |
+---+-----------------------------------+
|1 |{1 -> 90, 2 -> 97, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+
However, for performance reasons, I would like to avoid using applyInPandas or UDFs. Any ideas?
Upvotes: 1
Views: 522
Reputation: 1365
In the end I solved it like this:
import pyspark.sql.functions as sf
def sum_counter(mapcoln: str):
    # Distinct keys seen across all maps in the group.
    dkeys = sf.array_distinct(sf.flatten(sf.collect_list(sf.map_keys(mapcoln))))
    # For each distinct key, fold over the collected maps and sum its values.
    dkeyscount = sf.transform(
        dkeys,
        lambda ukey: sf.aggregate(
            sf.collect_list(mapcoln),
            sf.lit(0),
            lambda acc, mapentry: sf.when(
                ~sf.isnull(sf.element_at(mapentry, ukey)),
                acc + sf.element_at(mapentry, ukey),
            ).otherwise(acc),
        ),
    )
    return sf.map_from_arrays(dkeys, dkeyscount).alias("mapsum")
df.groupby("id").agg(sum_counter("omap")).show(truncate=False)
+---+-----------------------------------+
|id |mapsum |
+---+-----------------------------------+
|1 |{1 -> 90, 2 -> 97, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+
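For context, the core trick above is sf.aggregate, which folds an array column into a single value using a lambda (available in the Python API since Spark 3.1). Here is a minimal standalone sketch; the demo DataFrame and the nums column are made up purely for illustration:
# Hypothetical demo: sum the elements of an array column with sf.aggregate.
demo = spark.createDataFrame([([1, 2, 3],)], ["nums"])
demo.select(
    sf.aggregate("nums", sf.lit(0), lambda acc, x: acc + x).alias("total")
).show()
# expected: a single row with total = 6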
Upvotes: 0
Reputation: 1167
You can first explode the omap column into individual rows, with the key and value placed in separate columns, and then aggregate them like so:
exploded_df = df.select("*", sf.explode("omap"))
agg_df = exploded_df.groupBy("id", "key").sum("value")
agg_df.groupBy("id").agg(
    sf.map_from_entries(sf.collect_list(sf.struct("key", "sum(value)"))).alias("mapsum")
).show(truncate=False)
+---+-----------------------------------+
|id |mapsum |
+---+-----------------------------------+
|1 |{2 -> 97, 1 -> 90, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+
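If you prefer, the same explode-and-regroup logic can be chained into a single expression; this is a rough, equivalent sketch of the steps above, using the same functions:
# Sketch: explode the map, sum per (id, key), then rebuild the map per id.
(df.select("id", sf.explode("omap"))
    .groupBy("id", "key").agg(sf.sum("value").alias("value"))
    .groupBy("id")
    .agg(sf.map_from_entries(sf.collect_list(sf.struct("key", "value"))).alias("mapsum"))
    .show(truncate=False))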
Upvotes: 2