Abhishek Patil

Reputation: 165

Get sum of columns from a dataframe including map column - PySpark

I have a PySpark dataframe that looks like this. It has a map column of type Map<Str, Int>:

Date          Item (Map<Str, Int>)                      Total Items    ColA

2021-02-01    Item_A -> 3, Item_B -> 10, Item_C -> 2    15             10
2021-02-02    Item_A -> 1, Item_D -> 5, Item_E -> 7     13             20
2021-02-03    Item_A -> 8, Item_E -> 3, Item_C -> 1     12             30

I want to sum all the columns, including the map column. For the map column, the sum should be calculated per key.

I want something like this:

[[Item_A -> 12, Item_B -> 10, Item_C -> 3, Item_D -> 5, Item_E -> 10], 40, 60]

It does not have to be a list of lists; I just want the sum of each column.

My approach:

# fails on the map column: plain Python dicts do not support "+"
df.rdd.map(lambda x: (1, x[1])).reduceByKey(lambda x, y: x + y).collect()[0][1]
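
For reference, a runnable sketch of the same RDD idea, merging the map column with collections.Counter and summing the other two columns in one reduce (this assumes the columns are ordered Date, Item, Total Items, ColA):

from collections import Counter

# Sketch: merge the Item maps with Counter and sum the scalar
# columns in a single reduce over the rows.
merged, total_items, col_a = (
    df.rdd
      .map(lambda r: (Counter(r[1]), r[2], r[3]))
      .reduce(lambda a, b: (a[0] + b[0], a[1] + b[1], a[2] + b[2]))
)
# merged      -> Counter({'Item_A': 12, 'Item_B': 10, 'Item_E': 10, 'Item_D': 5, 'Item_C': 3})
# total_items -> 40
# col_a       -> 60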

Upvotes: 2

Views: 513

Answers (1)

ZygD

Reputation: 24386

You can do the aggregation for the map column and for the other columns separately: summing the map column requires an explode on the Item column, which multiplies the rows and would make the other columns hard to sum in the same pass.

Example dataframe:

from pyspark.sql import functions as F

df = spark.createDataFrame(
    [('2021-02-01', {'Item_A': 3, 'Item_B': 10, 'Item_C': 2}, 15, 10),
     ('2021-02-02', {'Item_A': 1, 'Item_D': 5, 'Item_E': 7}, 13, 20),
     ('2021-02-03', {'Item_A': 8, 'Item_E': 3, 'Item_C': 1}, 12, 30)],
    ['Date', 'Item', 'Total Items', 'ColA'])
df.show(truncate=0)
# +----------+----------------------------------------+-----------+----+
# |Date      |Item                                    |Total Items|ColA|
# +----------+----------------------------------------+-----------+----+
# |2021-02-01|{Item_C -> 2, Item_B -> 10, Item_A -> 3}|15         |10  |
# |2021-02-02|{Item_E -> 7, Item_D -> 5, Item_A -> 1} |13         |20  |
# |2021-02-03|{Item_E -> 3, Item_C -> 1, Item_A -> 8} |12         |30  |
# +----------+----------------------------------------+-----------+----+

Script:

aggs = df.agg(F.sum('Total Items'), F.sum('ColA')).head()  # scalar sums, one pass

df = (df
    .select('*', F.explode('Item'))      # one row per (key, value) entry of the map
    .groupBy('key')
    .agg(F.sum('value').alias('value'))  # per-key totals
    .select(
        F.map_from_entries(F.collect_set(F.struct('key', 'value'))).alias('Item'),
        F.lit(aggs[0]).alias('Total Items'),
        F.lit(aggs[1]).alias('ColA'),
    )
)
df.show(truncate=0)
# +--------------------------------------------------------------------+-----------+----+
# |Item                                                                |Total Items|ColA|
# +--------------------------------------------------------------------+-----------+----+
# |{Item_C -> 3, Item_E -> 10, Item_A -> 12, Item_B -> 10, Item_D -> 5}|40         |60  |
# +--------------------------------------------------------------------+-----------+----+
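
If you need the totals as plain Python values rather than a one-row dataframe, head() returns a Row in which the map column comes back as an ordinary dict:

row = df.head()
row['Item']         # {'Item_A': 12, 'Item_B': 10, 'Item_C': 3, 'Item_D': 5, 'Item_E': 10} (key order may vary)
row['Total Items']  # 40
row['ColA']         # 60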

Upvotes: 2
