Yoan B. M.Sc

Reputation: 1505

Spark: Control partitioning to reduce shuffle

I'm trying to wrap my head around the different ways of partitioning a dataframe in Spark in order to reduce the amount of shuffling on a specific pipeline.

Here is the DataFrame I'm working on; it contains more than 4 billion rows and 80 columns:

+-----+-------------------+-----------+
|  msn|          timestamp| Flight_Id |
+-----+-------------------+-----------+
|50020|2020-08-22 19:16:00|       72.0|
|50020|2020-08-22 19:15:00|       84.0|
|50020|2020-08-22 19:14:00|       96.0|
|50020|2020-08-22 19:13:00|       84.0|
|50020|2020-08-22 19:12:00|       84.0|
|50020|2020-08-22 19:11:00|       84.0|
|50020|2020-08-22 19:10:00|       84.0|
|50020|2020-08-22 19:09:00|       84.0|
|50020|2020-08-22 19:08:00|       84.0|
|50020|2020-08-22 19:07:00|       84.0|
|50020|2020-08-22 19:06:00|       84.0|
|50020|2020-08-22 19:05:00|       84.0|
|50020|2020-08-22 19:04:00|       84.0|
|50020|2020-08-22 19:03:00|       84.0|
|50020|2020-08-22 19:02:00|       84.0|
|50020|2020-08-22 19:01:00|       84.0|
|50020|2020-08-22 19:00:00|       84.0|
|50020|2020-08-22 18:59:00|       84.0|
|50020|2020-08-22 18:58:00|       84.0|
|50020|2020-08-22 18:57:00|       84.0|
+-----+-------------------+-----------+

This represents a collection of time series for different aircraft (41 aircraft in total). I'm only doing two things on this data:

  1. Filter to keep the last 30 min of each flight, using a window partitioned by MSN and Flight_ID and ordered by timestamp.
  2. On the remaining columns, compute the mean and stdev and normalize the data.

I have 32 executors with 12g of memory each, and the job crashed after running for 30 hours with the following message:

The driver running the job crashed, ran out of memory, or otherwise became unresponsive while it was running.

Looking at the query plan, I noticed it has over 300 steps, more than 60 of which involve a shuffle (the physical plan of every step looks exactly the same):

AdaptiveSparkPlan(isFinalPlan=false)
+- CollectLimit 1
   +- HashAggregate(keys=[], functions=[avg(3546001_421#213), stddev_samp(3546001_421#213)], output=[avg(3546001_421)#10408, stddev_samp(3546001_421)#10417])
      +- Exchange SinglePartition, true
         +- HashAggregate(keys=[], functions=[partial_avg(3546001_421#213), partial_stddev_samp(3546001_421#213)], output=[sum#10479, count#10480L, n#10423, avg#10424, m2#10425])
            +- Project [3546001_421#213]
               +- Filter (isnotnull(rank#10238) && (rank#10238 <= 1800))
                  +- Window [rank(timestamp#10081) windowspecdefinition(Flight_Id_Int#209, timestamp#10081 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rank#10238], [Flight_Id_Int#209], [timestamp#10081 DESC NULLS LAST]
                     +- Sort [Flight_Id_Int#209 ASC NULLS FIRST, timestamp#10081 DESC NULLS LAST], false, 0
                        +- ShuffleQueryStage 0
                           +- Exchange hashpartitioning(Flight_Id_Int#209, 200), true
                              +- Union
                                 :- *(1) Project [Flight_Id_Int#209, cast((cast(timestamp#212L as double) / 1.0E9) as timestamp) AS timestamp#10081, 3546001_421#213]

I have a strong feeling that partitioning by msn first would help reduce the amount of shuffling, since most of the work is at the msn level.

My questions are: how and where in my code should I repartition? Should I use repartition, repartition with a key, or hash partitioning? I've been reading the documentation on these different partitioners and I'm confused about their use, and about whether that's actually the solution to my problem.

Thank you.

EDIT 1:

Here is the logical plan:

GlobalLimit 1
+- LocalLimit 1
   +- Aggregate [avg(3566000_421#214) AS avg(3566000_421)#10594, stddev_samp(3566000_421#214) AS stddev_samp(3566000_421)#10603]
      +- Project [3566000_421#214]
         +- Filter (isnotnull(rank#10238) && (rank#10238 <= 1800))
            +- Window [rank(timestamp#10081) windowspecdefinition(msn#208, Flight_Id_Int#209, timestamp#10081 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rank#10238], [msn#208, Flight_Id_Int#209], [timestamp#10081 DESC NULLS LAST]
               +- Union
                  :- Project [msn#208, Flight_Id_Int#209, cast((cast(timestamp#212L as double) / 1.0E9) as timestamp) AS timestamp#10081, 3566000_421#214]

Here is the section of the code where I pull the data together from the data lake it is stored in. FYI, this is done through a custom API using a library called FoundryTS. The important thing is that no data is collected until I call the to_dataframe() method. I'm looping over each msn to avoid making one call that is too big, and then I merge all the DataFrames together with unionByName.

# Loop over MSN to extract time series.
# F is pyspark.sql.functions; Interval, fts, FF and SeriesMetadata come from FoundryTS (imports not shown).
from pyspark.sql import functions as F

dfs = []
for msn in msn_range:
    search_results = (SeriesMetadata.M_REPORT_NUMBER == report_number) & (SeriesMetadata.M_AIRCRAFT == msn)

    # Create the intervals to split the time series extract by flight for each MSN.
    # One query for Start, End and id_cmsReport keeps the three values of each flight on the same row.
    flights_pdf = df1.where(F.col("msn") == msn).select("Start", "End", "id_cmsReport").toPandas()
    flights_interval = [
        Interval(start, end, name=flight_id)
        for start, end, flight_id in zip(flights_pdf["Start"], flights_pdf["End"], flights_pdf["id_cmsReport"])
    ]

    # Collect all the series into a node collection
    output = (
        fts.search.series(
            search_results,
            object_types=["export-control-us-ear99-a220-dal-airline-series"])
        .map_by(FF.interpolate(
            before='nearest',
            internal='nearest',
            after='nearest',
            frequency=frequency,
            rename_columns_by=lambda x: x.metadata["parameter_id"] + "_" + x.metadata["report_number"]),
            keys='msn')
        .map_intervals(flights_interval, interval_name='Flight_Id_Int')
        .map(FF.time_range(period_start, period_end))
        .to_dataframe()  # TODO: try numPartitions=32 (Foundry doc: #partitions = #executors) and see if it triggers an OOM error
    )

    dfs.append(output)

# unionByName is the same as union but matches columns by name instead of by position
output = dfs[0]
for part in dfs[1:]:
    output = output.unionByName(part)

# Repartition by msn to improve later calculations.
# repartition() returns a new DataFrame, so the result has to be reassigned.
N = len(msn_range)
output = output.repartition(N, 'msn')

Upvotes: 1

Views: 1757

Answers (2)

Yoan B. M.Sc

Reputation: 1505

For those it might help:

Here are the things I got wrong in the partitioning:

  1. .to_dataframe(): by default, Spark creates 200 partitions on our cloud platform. So by looping over the 40 MSNs I was generating 40 x 200 = 8,000 partitions, and I ended up with a lot of small tasks to manage.
  2. .repartition(): since I was using a Window with partitionBy on msn, I thought repartitioning by msn would speed up this step. Instead, it introduced a full shuffle of my partitions.

Result: 59 GB of shuffle write according to the Spark job tracker, and more than 55k tasks. Since each task adds some overhead, this would explain the driver crashing.

What I did to make it work:

  1. I got rid of the Window function

I filtered earlier in the process, before fetching the data from the data lake, so that I directly extracted only the part of each flight I needed. As a consequence, there are fewer Exchange operators in the physical plan for the exact same section.
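Filtering at extraction time amounts to shrinking each flight interval before it is passed to the extract. A sketch under assumptions: the helper `last_window_intervals` is hypothetical, it works on plain `(start, end, flight_id)` tuples such as the `Start`, `End` and `id_cmsReport` values collected from `df1`, and the resulting tuples would still have to be turned into FoundryTS `Interval` objects.

```python
from datetime import datetime, timedelta


def last_window_intervals(flights, window=timedelta(minutes=30)):
    """Clip each (start, end, flight_id) interval to its final `window`.

    Flights shorter than `window` are kept whole.
    """
    return [
        (max(start, end - window), end, flight_id)
        for start, end, flight_id in flights
    ]
```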

Here is the updated physical plan:

AdaptiveSparkPlan(isFinalPlan=false)
+- CollectLimit 1
   +- HashAggregate(keys=[], functions=[avg(3565000_421#213), stddev_samp(3565000_421#213)], output=[avg(3565000_421)#10246, stddev_samp(3565000_421)#10255])
      +- ShuffleQueryStage 0
         +- Exchange SinglePartition, true
            +- *(43) HashAggregate(keys=[], functions=[partial_avg(3565000_421#213), partial_stddev_samp(3565000_421#213)], output=[sum#10317, count#10318L, n#10261, avg#10262, m2#10263])
               +- Union
                  :- *(1) Project [3565000_421#213]
                  :  +- *(1) Scan ExistingRDD[msn#208,Flight_Id_Int#209,Flight_Id_Int.start#210L

  2. I decreased the number of partitions:

I set it arbitrarily to 5 in the .to_dataframe() call for each of the 40 MSNs.

The build succeeded after 24h, with 1.1 MB of shuffle write and more than 27k tasks.

As @Andrew Long pointed out, it's probably suboptimal because of the for loop. I'm still generating at least 200 partitions for 32 executors, ending up with more than 27k tasks to manage.

Finally:

As a last step, I got rid of the for loop by relying on the underlying API to fetch the data as one big DataFrame instead, and I forced the number of partitions to 32. The build is still running as of this writing and I will edit the post with the result, but there are indeed far fewer tasks to manage (by a factor of 4).

EDIT 1 - update

Happy to report that by getting rid of the for loop and splitting the DataFrame into 64 partitions (32 executors x 2 cores), I was able to run the same job in only 11h (instead of 24h), with 1.9 MB of shuffle write and only 5k tasks.

PS: I mentioned 32 (and not 64) partitions above, but the job did not succeed with 32 and had suboptimal parallelism (< 20), so it was taking longer and I had idle executors. In my case, 64 seems to be the sweet spot.
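The "32 executors x 2 cores = 64 partitions" rule of thumb can be written down as a tiny helper. A sketch only: `target_partitions` and `repartition_for_cluster` are hypothetical names, the executor and core counts are passed in by hand, and `tasks_per_core` is an assumed knob for running more than one wave of tasks. On a running cluster, `spark.sparkContext.defaultParallelism` usually reports the total number of executor cores unless `spark.default.parallelism` is set.

```python
def target_partitions(num_executors: int, cores_per_executor: int, tasks_per_core: int = 1) -> int:
    """Partition count that gives every executor core the same number of tasks."""
    return num_executors * cores_per_executor * tasks_per_core


def repartition_for_cluster(df, num_executors: int, cores_per_executor: int, tasks_per_core: int = 1):
    """Round-robin repartition of a Spark DataFrame to match the cluster's core count."""
    return df.repartition(target_partitions(num_executors, cores_per_executor, tasks_per_core))
```

With the figures from this answer, `target_partitions(32, 2)` gives the 64 partitions that worked best here.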

Upvotes: 1

Andrew Long

Reputation: 933

"The driver running the job crashed, ran out of memory, or otherwise became unresponsive while it was running."

The first problem you need to fix is to bump up the memory of the driver (not the executors). The default driver memory in Spark (1g) is often so low that it will crash on many queries.

"My question are How and Where in my code I should repartition"

Spark already adds repartitions (exchanges) wherever they are necessary. Chances are you will only create extra work by manually repartitioning the data halfway through execution. One potential optimization is to store the data in a bucketed table, but that will at best remove the first exchange, and only if your bucketing column exactly matches the hash partitioning of that exchange.

"Looking at the Query Plan I noticed I have over 300 steps"

What you described above does not take 300 steps. Something seems off here. What does your optimized logical plan look like? Mean and std should only require a scan -> partial agg -> exchange -> final agg. In the query plan you provided, it looks like you're intentionally only looking at the last 1800 data points instead of the last 30 minutes. Did you mean to do a window function and not a simple aggregate (aka group by)?

EDIT:

for msn in msn_range:

IMO this might be part of your problem. This for loop makes the execution plan very large, which may be why you're getting OOM issues on the driver. You might be able to translate it into something that's a bit more Spark-friendly and does less work on the driver, by converting that for loop into a spark.sparkContext.parallelize(...).map(/your code/).

Upvotes: 1
