Mysterious

Reputation: 881

Pyspark - groupby with filter - Optimizing speed

I have billions of rows to process using PySpark.

Dataframe looks like this:

category    value    flag
   A          10       1
   A          12       0
   B          15       0
and so on...

I need to run two groupby operations: one on rows where flag == 1 and the other on ALL rows. Currently I am doing this:

import pyspark.sql.functions as F

frame_1 = df.filter(df.flag == 1).groupBy('category').agg(F.sum('value').alias('foo1'))
frame_2 = df.groupBy('category').agg(F.sum('value').alias('foo2'))
final_frame = frame_1.join(frame_2, on='category', how='left')

This code runs, but it is very slow. Is there a way to improve its speed, or is this the limit? I understand that PySpark's lazy evaluation takes some time, but is this code the best way of doing this?

Upvotes: 1

Views: 3071

Answers (2)

pault

Reputation: 43494

IIUC, you can avoid the expensive join and achieve this using one groupBy.

final_frame_2 = df.groupBy("category").agg(
    F.sum(F.col("value")*F.col("flag")).alias("foo1"),
    F.sum(F.col("value")).alias("foo2"),
)
final_frame_2.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#|       B| 0.0|15.0|
#|       A|10.0|22.0|
#+--------+----+----+

Now compare the execution plans:

First your method:

final_frame.explain()
#== Physical Plan ==
#*(5) Project [category#0, foo1#68, foo2#75]
#+- SortMergeJoin [category#0], [category#78], LeftOuter
#   :- *(2) Sort [category#0 ASC NULLS FIRST], false, 0
#   :  +- *(2) HashAggregate(keys=[category#0], functions=[sum(cast(value#1 as double))])
#   :     +- Exchange hashpartitioning(category#0, 200)
#   :        +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum(cast(value#1 as double))])
#   :           +- *(1) Project [category#0, value#1]
#   :              +- *(1) Filter (isnotnull(flag#2) && (cast(flag#2 as int) = 1))
#   :                 +- Scan ExistingRDD[category#0,value#1,flag#2]
#   +- *(4) Sort [category#78 ASC NULLS FIRST], false, 0
#      +- *(4) HashAggregate(keys=[category#78], functions=[sum(cast(value#79 as double))])
#         +- Exchange hashpartitioning(category#78, 200)
#            +- *(3) HashAggregate(keys=[category#78], functions=[partial_sum(cast(value#79 as double))])
#               +- *(3) Project [category#78, value#79]
#                  +- Scan ExistingRDD[category#78,value#79,flag#80]

Now the same for final_frame_2:

final_frame_2.explain()
#== Physical Plan ==
#*(2) HashAggregate(keys=[category#0], functions=[sum((cast(value#1 as double) * cast(flag#2 as double))), sum(cast(value#1 as double))])
#+- Exchange hashpartitioning(category#0, 200)
#   +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum((cast(value#1 as double) * cast(flag#2 as double))), partial_sum(cast(value#1 as double))])
#      +- Scan ExistingRDD[category#0,value#1,flag#2]

Note: Strictly speaking, this is not the exact same output as the example you gave (shown below), because your join is driven by the filtered frame_1 on the left, so it eliminates every category that has no row with flag = 1.

final_frame.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#|       A|10.0|22.0|
#+--------+----+----+

If that behaviour is a requirement, you can add an aggregation that sums flag and filter out the categories where that sum is zero, with only a minor hit to performance.

final_frame_3 = df.groupBy("category").agg(
    F.sum(F.col("value")*F.col("flag")).alias("foo1"),
    F.sum(F.col("value")).alias("foo2"),
    F.sum(F.col("flag")).alias("foo3")
).where(F.col("foo3")!=0).drop("foo3")

final_frame_3.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#|       A|10.0|22.0|
#+--------+----+----+
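As an aside, the conditional sum can also be written with F.when instead of value * flag, which makes the intent explicit and does not depend on flag being numeric. A minimal sketch assuming the same df and import as above (the names conditional_value and final_frame_when are just illustrative):

import pyspark.sql.functions as F

# same single-pass aggregation, but with an explicit conditional instead of value * flag
conditional_value = F.when(F.col("flag") == 1, F.col("value")).otherwise(0)

final_frame_when = df.groupBy("category").agg(
    F.sum(conditional_value).alias("foo1"),
    F.sum("value").alias("foo2"),
)
final_frame_when.show()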

Upvotes: 2

pissall

Reputation: 7399

Please note that a join operation is expensive. You can simply add flag to your grouping columns instead:

frame_1 = df.groupBy(["category", "flag"]).agg(F.sum('value').alias('foo1'))

If you have more than two flag values and you want flag == 1 vs. the rest, then:

import pyspark.sql.functions as F

frame_1 = df.withColumn("flag2", F.when(F.col("flag") == 1, 1).otherwise(0))
frame_1 = frame_1.groupBy(["category", "flag2"]).agg(F.sum('value').alias('foo1'))

If you also want the aggregation over all rows, make a new frame that rolls the first one up by category alone:

frame_2 = frame_1.groupBy("category").agg(F.sum('foo1').alias('foo2'))

It is not possible to do both in one step, because essentially there is a group overlap.
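To get back to a single per-category frame like the one in the question, you would then combine the two frames; since both are already aggregated and therefore small, that join is far cheaper than joining the raw data. A minimal sketch under that assumption (the names per_flag, totals and result are illustrative):

import pyspark.sql.functions as F

# per-(category, flag) sums over the raw data -- the only wide shuffle over df
per_flag = df.groupBy("category", "flag").agg(F.sum("value").alias("foo1"))

# roll the small aggregated frame up to per-category totals
totals = per_flag.groupBy("category").agg(F.sum("foo1").alias("foo2"))

# keep the flag == 1 partial sums and attach the totals
result = (
    per_flag.filter(F.col("flag") == 1)
    .drop("flag")
    .join(totals, on="category", how="left")
)
result.show()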

Upvotes: 0
