human

Reputation: 2441

Why is Dataset.count causing a shuffle? (Spark 2.2)

Here is my dataframe:

[image: the DataFrame]

The underlying RDD has 2 partitions:

[images: partitions of the underlying RDD]

When I run df.count, the DAG produced is:

[image: DAG for df.count]

When I do a df.rdd.count, the DAG produced is:

[image: DAG for df.rdd.count]

Question: count is an action in Spark, and the official documentation describes it as ‘Returns the number of rows in the DataFrame.’ So why does a shuffle occur when I call count on the DataFrame, while no shuffle occurs when I do the same on the underlying RDD?

It makes no sense to me why a shuffle would occur at all. I tried to go through the source code of count on the Spark GitHub repository, but I don't fully understand it. Is the groupBy being supplied to the action the culprit?

PS: df.coalesce(1).count does not cause any shuffle.

Upvotes: 17

Views: 7147

Answers (3)

Tom

Reputation: 6342

In the shuffle stage, the key is empty and the value is the count for the partition; all of these (key, value) pairs are shuffled to a single partition.

That is, very little data is moved in the shuffle stage: one small record per partition.
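A toy model in plain Scala collections can make this concrete (a sketch, not Spark's implementation; CountShuffleModel and its methods are hypothetical names). Each partition emits exactly one (key, partialCount) record with an empty key, every record is routed to one reducer partition, and the reducer adds them up, so the number of shuffled records equals the number of partitions rather than the number of rows.

```scala
// Toy model of the count shuffle described above; plain Scala, not Spark code.
object CountShuffleModel {
  // Map side: every partition emits exactly one (key, partialCount) record.
  // The key is Unit because a global count has no grouping columns.
  def mapSide[A](partitions: Seq[Seq[A]]): Seq[(Unit, Long)] =
    partitions.map(p => ((), p.size.toLong))

  // Shuffle: with an empty key, every record is routed to reducer partition 0.
  def shuffleToSinglePartition(records: Seq[(Unit, Long)]): Map[Int, Seq[(Unit, Long)]] =
    records.groupBy(_ => 0)

  // Reduce side: the single reducer adds up the partial counts.
  def reduceSide(shuffled: Map[Int, Seq[(Unit, Long)]]): Long =
    shuffled.getOrElse(0, Seq.empty).map(_._2).sum

  // Returns (total row count, number of records that went through the shuffle).
  def count[A](partitions: Seq[Seq[A]]): (Long, Int) = {
    val records = mapSide(partitions)
    (reduceSide(shuffleToSinglePartition(records)), records.size)
  }
}
```

For example, with two partitions of 1000 and 1500 rows, `CountShuffleModel.count` reports a total of 2500 while only 2 records are shuffled.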

Upvotes: 3

Pratyush Sharma

Reputation: 289

It seems that DataFrame's count operation uses groupBy, which results in a shuffle. Below is the code from https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

/**
 * Returns the number of rows in the Dataset.
 * @group action
 * @since 1.6.0
 */
def count(): Long = withAction("count", groupBy().count().queryExecution) { plan =>
  plan.executeCollect().head.getLong(0)
}
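To see why a groupBy shows up in count, here is a plain-Scala analogy (not Spark code; GlobalGroupByModel and its methods are hypothetical). Grouping by a key that ignores the row puts every row into one group, so counting that group gives the total. Spark's groupBy() with no columns is likewise a global aggregation, and bringing the partial results of that single group together across partitions is what calls for an exchange.

```scala
// Plain-Scala analogy for Dataset.count = groupBy().count(); not Spark's implementation.
object GlobalGroupByModel {
  // groupBy() with no columns: every row gets the same empty key, so there is at most one group.
  def groupByNoKeys[A](rows: Seq[A]): Map[Unit, Seq[A]] =
    rows.groupBy(_ => ())

  // groupBy().count(): a global aggregate yields exactly one row, even for empty input.
  def groupByNoKeysCount[A](rows: Seq[A]): Seq[Long] = {
    val groups = groupByNoKeys(rows)
    if (groups.isEmpty) Seq(0L) else groups.values.map(_.size.toLong).toSeq
  }

  // Mirrors plan.executeCollect().head.getLong(0): take the single row's only value.
  def count[A](rows: Seq[A]): Long =
    groupByNoKeysCount(rows).head
}
```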

RDD's count, by contrast, uses sc.runJob to apply Utils.getIteratorSize to each partition. That returns an Array holding the number of elements in each partition, and .sum then adds those counts on the driver.

Code snippet from this link: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/rdd/RDD.scala

/**
* Return the number of elements in the RDD.
*/
def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum
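For comparison, a plain-Scala model of that RDD path (hypothetical names; iteratorSize and runJob here only stand in for Spark's Utils.getIteratorSize and SparkContext.runJob and are not their real implementations): each partition is reduced to a single Long where it lives, the per-partition results go back to the driver, and the driver sums them. Because the combining happens on the driver rather than in a reducer task, no shuffle stage is needed.

```scala
// Plain-Scala model of RDD.count = sc.runJob(this, Utils.getIteratorSize _).sum; not Spark code.
object RddCountModel {
  // Stand-in for Utils.getIteratorSize: walk the iterator once and count its elements.
  def iteratorSize(it: Iterator[_]): Long = {
    var n = 0L
    while (it.hasNext) { it.next(); n += 1 }
    n
  }

  // Stand-in for sc.runJob: apply f to each partition's iterator and return
  // the per-partition results to the "driver" (Spark returns an Array here).
  def runJob[A, U](partitions: Seq[Seq[A]])(f: Iterator[A] => U): Seq[U] =
    partitions.map(p => f(p.iterator))

  // The driver sums the per-partition counts; no records move between partitions.
  def count[A](partitions: Seq[Seq[A]]): Long =
    runJob(partitions)(it => iteratorSize(it)).sum
}
```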

Upvotes: 10

nefo_x

Reputation: 3088

When Spark counts a DataFrame, it first computes a partial count for every partition and then uses another stage to sum those partial counts. This is particularly good for large DataFrames, where computing the partial counts in parallel on multiple executors improves performance.

You can verify this in the SQL tab of the Spark UI, which shows a physical plan similar to the following:

*HashAggregate(keys=[], functions=[count(1)], output=[count#202L])
+- Exchange SinglePartition
   +- *HashAggregate(keys=[], functions=[partial_count(1)], output=[count#206L])
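The three operators in that plan can be mimicked with plain Scala collections (a sketch, not Spark code; TwoPhaseCountModel and its methods are hypothetical). The model also assumes, as one reading consistent with the question's df.coalesce(1).count observation, that no exchange is needed when the input already sits in a single partition; with several partitions, only one small partial count per partition is exchanged.

```scala
// Plain-Scala model of: partial_count -> Exchange SinglePartition -> count. Not Spark code.
object TwoPhaseCountModel {
  // HashAggregate(keys=[], functions=[partial_count(1)]): one partial count per input partition.
  def partialCount[A](partitions: Seq[Seq[A]]): Seq[Seq[Long]] =
    partitions.map(p => Seq(p.size.toLong))

  // Exchange SinglePartition: gather all partial results into one partition.
  // Assumption for this model: with exactly one input partition, no exchange is inserted.
  def exchangeSinglePartition(partials: Seq[Seq[Long]]): (Seq[Seq[Long]], Boolean) =
    if (partials.size == 1) (partials, false)
    else (Seq(partials.flatten), true)

  // HashAggregate(keys=[], functions=[count(1)]): merge the partial counts by summing them.
  def finalCount(single: Seq[Seq[Long]]): Long =
    single.flatten.sum

  // Returns (total row count, whether the modeled plan needed an exchange).
  def run[A](partitions: Seq[Seq[A]]): (Long, Boolean) = {
    val (single, shuffled) = exchangeSinglePartition(partialCount(partitions))
    (finalCount(single), shuffled)
  }
}
```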

Upvotes: 5
