Lukas Boersma

Reputation: 1082

How to compute multiple counts with different conditions on a pyspark DataFrame, fast?

Let's say I have this pyspark Dataframe:

data = spark.createDataFrame(schema=['Country'], data=[('AT',), ('BE',), ('France',), ('Latvia',)])

And let's say I want to collect various statistics about this data. For example, I might want to know how many rows use a 2-character country code and how many use longer country names:

count_short = data.where(F.length(F.col('Country')) == 2).count()
count_long = data.where(F.length(F.col('Country')) > 2).count()

This works, but when I want to collect many different counts based on different conditions, it becomes very slow even for tiny datasets. In Azure Synapse Studio, where I am working, every count takes 1-2 seconds to compute.

I need to do 100+ counts, and it takes multiple minutes to compute for a dataset of 10 rows. And before somebody asks, the conditions for those counts are more complex than in my example. I cannot group by length or do other tricks like that.

I am looking for a general way to do multiple counts on arbitrary conditions, fast.

I am guessing that the reason for the slow performance is that for every count call, my pyspark notebook starts some Spark processes that have significant overhead. So I assume that if there was some way to collect these counts in a single query, my performance problems would be solved.

One possible solution I thought of is to build a temporary column that indicates which of my conditions have been matched, and then call countDistinct on it. But then I would have individual counts for all combinations of condition matches. I also noticed that depending on the situation, the performance is a bit better when I do data = data.localCheckpoint() before computing my statistics, but the general problem still persists.

Is there a better way?

Upvotes: 0

Views: 2395

Answers (3)

Nikunj Kakadiya

Reputation: 2988

A few things to keep in mind: if you are applying multiple actions to your dataframe, it involves a lot of transformations, and the data is read from an external source, then you should definitely cache that dataframe before applying the first action.

The answer provided by @pasha701 works, but you will have to keep adding columns for every country code length you want to analyse.

You can use the below code to get the counts for all the different country code lengths in one single dataframe.

# import statements
from pyspark.sql.functions import *
# sample dataframe
data = spark.createDataFrame(schema=['Country'], data=[('AT',), ('ACE',), ('BE',), ('France',), ('Latvia',)])
# add a column that gives the length of each country code
data1 = data.withColumn("CountryLength", length(col('Country')))
# column names for the final output
outputcolumns = ["CountryLength", "RecordsCount"]
# select the CountryLength column, convert it to an RDD, and run a
# map/reduce to count the occurrences of each length
countrieslength = (data1.select("CountryLength")
                   .rdd.map(lambda row: (row, 1))
                   .reduceByKey(lambda a, b: a + b)
                   .toDF(outputcolumns)
                   .select("CountryLength.CountryLength", "RecordsCount"))
# now you can do display or show on the dataframe to see the output
display(countrieslength)

For the sample data above, the output has one row per distinct length: 2 rows of length 2 (AT, BE), 1 row of length 3 (ACE), and 2 rows of length 6 (France, Latvia).

If you want to apply multiple filter conditions to this dataframe, you can cache it and then get the counts for different combinations of records based on the country code length.

Upvotes: 0

greenie

Reputation: 444

While one way is to combine multiple queries into one, the other way is to cache the dataframe that is being queried again and again. By caching the dataframe, we avoid re-evaluating it each time count() is invoked.

data.cache()

Upvotes: 1

pasha701

Reputation: 7207

Function "count" can be replaced by "sum" with condition (Scala):

data.select(
  sum(
    when(length(col("Country")) === 2, 1).otherwise(0)
  ).alias("two_characters"),
  sum(
    when(length(col("Country")) > 2, 1).otherwise(0)
  ).alias("more_than_two_characters")
)

Upvotes: 2
