Saurabh Sharma

Reputation: 325

Need to Know Partitioning Details of a Spark DataFrame

I am trying to read from a DB2 database based on a query. The result set of the query is about 20-40 million records. The DataFrame is partitioned on an integer column.

My question is: once the data is loaded, how can I check how many records were created per partition? Basically, I want to check whether data skew is happening. How can I check the record counts per partition?
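For context, the read is a partitioned JDBC load along the lines of the sketch below (the URL, query, column name, and bounds are placeholders, not my actual values):

df = (spark.read.format("jdbc")
      .option("url", "jdbc:db2://host:50000/MYDB")           # placeholder DB2 URL
      .option("driver", "com.ibm.db2.jcc.DB2Driver")
      .option("dbtable", "(SELECT ... FROM my_table) AS q")  # the query as a derived table
      .option("partitionColumn", "id")                       # the integer column
      .option("lowerBound", 1)                               # assumed min of the column
      .option("upperBound", 40000000)                        # assumed max of the column
      .option("numPartitions", 20)
      .load())

Spark splits the [lowerBound, upperBound] range of partitionColumn into numPartitions stride ranges and issues one query per partition.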

Upvotes: 5

Views: 10638

Answers (2)

cph_sto

Reputation: 7585

Let's create a DataFrame first.

rdd = sc.parallelize([('a', 22), ('b', 1), ('c', 4), ('b', 1), ('d', 2), ('e', 0),
                      ('d', 3), ('a', 1), ('c', 4), ('b', 7), ('a', 2), ('f', 1)])
df = rdd.toDF(['key', 'value'])
df = df.repartition(5, "key")  # hash-partition into 5 partitions on "key"

The number of partitions -

print("Number of partitions: {}".format(df.rdd.getNumPartitions())) 
    Number of partitions: 5

The number of rows in each partition. This can give you an idea of skew -

print('Partitioning distribution: ' + str(df.rdd.glom().map(len).collect()))
    Partitioning distribution: [3, 3, 2, 2, 2]
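Alternatively, you can stay in the DataFrame API and group by spark_partition_id(); only the per-partition counts are brought to the driver, so this also scales to large data -

from pyspark.sql.functions import spark_partition_id

df.groupBy(spark_partition_id().alias("partition_id")).count().orderBy("partition_id").show()
    +------------+-----+
    |partition_id|count|
    +------------+-----+
    |           0|    3|
    |           1|    3|
    |           2|    2|
    |           3|    2|
    |           4|    2|
    +------------+-----+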

See how the rows are actually distributed across the partitions. Beware that if the dataset is large, collecting it like this can crash your driver with an out-of-memory error.

print("Partitions structure: {}".format(df.rdd.glom().collect()))
    Partitions structure: [
       #Partition 1        [Row(key='a', value=22), Row(key='a', value=1), Row(key='a', value=2)], 
       #Partition 2        [Row(key='b', value=1), Row(key='b', value=1), Row(key='b', value=7)], 
       #Partition 3        [Row(key='c', value=4), Row(key='c', value=4)], 
       #Partition 4        [Row(key='e', value=0), Row(key='f', value=1)], 
       #Partition 5        [Row(key='d', value=2), Row(key='d', value=3)]
                          ]

Upvotes: 9

bluenote10

Reputation: 26530

You can, for instance, map over the partitions and determine their sizes:

val rdd = sc.parallelize(0 until 1000, 3)
val partitionSizes = rdd.mapPartitions(iter => Iterator(iter.length)).collect()

// would be Array(333, 333, 334) in this example

This works for both the RDD and the Dataset/DataFrame API.
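For completeness, a rough PySpark version of the same pattern (df is any DataFrame):

sizes = df.rdd.mapPartitions(lambda it: [sum(1 for _ in it)]).collect()
print(sizes)  # one count per partition, e.g. [333, 333, 334] for the example above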

Upvotes: 0
