newleaf

Reputation: 2457

PySpark count rows on condition

I have a dataframe

test = spark.createDataFrame([('bn', 12452, 221), ('mb', 14521, 330), ('bn', 2, 220), ('mb', 14520, 331)], ['x', 'y', 'z'])
test.show()
# +---+-----+---+
# |  x|    y|  z|
# +---+-----+---+
# | bn|12452|221|
# | mb|14521|330|
# | bn|    2|220|
# | mb|14520|331|
# +---+-----+---+

I need to count the rows based on a condition:

from pyspark.sql.functions import col, count

test.groupBy("x").agg(count(col("y") > 12453), count(col("z") > 230)).show()

which gives

+---+------------------+----------------+
|  x|count((y > 12453))|count((z > 230))|
+---+------------------+----------------+
| bn|                 2|               2|
| mb|                 2|               2|
+---+------------------+----------------+

This just gives the total row count per group, not the count of rows matching each condition.

Upvotes: 39

Views: 103100

Answers (5)

ZygD

Reputation: 24478

Spark 3.5+ has count_if in the Python API:

from pyspark.sql import functions as F

test.groupBy('x').agg(
    F.count_if(F.col('y') > 12453).alias('y_cnt'),
    F.count_if(F.col('z') > 230).alias('z_cnt')
).show()
# +---+-----+-----+
# |  x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn|    0|    0|
# | mb|    2|    2|
# +---+-----+-----+

Spark 3.0+ also has count_if, but it must be called through expr:

test.groupBy('x').agg(
    F.expr("count_if(y > 12453) y_cnt"),
    F.expr("count_if(z > 230) z_cnt")
).show()
# +---+-----+-----+
# |  x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn|    0|    0|
# | mb|    2|    2|
# +---+-----+-----+

Upvotes: 3

The count function skips null values, so you can try this:

import pyspark.sql.functions as F

def count_with_condition(cond):
    # when(cond, True) yields null where cond is false, and count skips nulls
    return F.count(F.when(cond, True))

There is also a ready-made function for this in this repo: kolang
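
For example, applied to the test dataframe from the question (the y_cnt / z_cnt aliases are just illustrative):

test.groupBy('x').agg(
    count_with_condition(F.col('y') > 12453).alias('y_cnt'),
    count_with_condition(F.col('z') > 230).alias('z_cnt')
).show()
# +---+-----+-----+
# |  x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn|    0|    0|
# | mb|    2|    2|
# +---+-----+-----+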

Upvotes: 3

rwitzel

Reputation: 1828

Since Spark 3.0.0 there is count_if(expr); see the Spark function documentation.
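
A minimal sketch of calling it through Spark SQL (the temp view name is illustrative):

test.createOrReplaceTempView("test")
spark.sql("""
    SELECT x,
           count_if(y > 12453) AS y_cnt,
           count_if(z > 230)   AS z_cnt
    FROM test
    GROUP BY x
""").show()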

Upvotes: 2

newleaf

Reputation: 2457

Based on @Psidom's answer, my answer is as follows:

from pyspark.sql.functions import col, when, count

test.groupBy("x").agg(
    count(when(col("y") > 12453, True)),  # when() without otherwise() yields null, which count() skips
    count(when(col("z") > 230, True))
).show()
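
For readable column names, the same aggregation with aliases (a minor variation on the answer; the alias names are illustrative):

test.groupBy("x").agg(
    count(when(col("y") > 12453, True)).alias("y_cnt"),
    count(when(col("z") > 230, True)).alias("z_cnt")
).show()
# +---+-----+-----+
# |  x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn|    0|    0|
# | mb|    2|    2|
# +---+-----+-----+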

Upvotes: 37

akuiper

Reputation: 215117

count doesn't sum Trues; it only counts the number of non-null values. To count the True values, you need to convert the conditions to 1 / 0 and then sum:

import pyspark.sql.functions as F

cnt_cond = lambda cond: F.sum(F.when(cond, 1).otherwise(0))
test.groupBy('x').agg(
    cnt_cond(F.col('y') > 12453).alias('y_cnt'), 
    cnt_cond(F.col('z') > 230).alias('z_cnt')
).show()
# +---+-----+-----+
# |  x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn|    0|    0|
# | mb|    2|    2|
# +---+-----+-----+
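
An equivalent, slightly terser variant (a sketch, not part of the original answer) casts the boolean condition directly to an integer before summing:

import pyspark.sql.functions as F

# true -> 1, false -> 0, so the sum is the conditional count
cnt_cond = lambda cond: F.sum(cond.cast('int'))
test.groupBy('x').agg(
    cnt_cond(F.col('y') > 12453).alias('y_cnt'),
    cnt_cond(F.col('z') > 230).alias('z_cnt')
).show()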

Upvotes: 76
