Reputation: 2457
I have a dataframe
test = spark.createDataFrame([('bn', 12452, 221), ('mb', 14521, 330), ('bn', 2, 220), ('mb', 14520, 331)], ['x', 'y', 'z'])
test.show()
# +---+-----+---+
# | x| y| z|
# +---+-----+---+
# | bn|12452|221|
# | mb|14521|330|
# | bn| 2|220|
# | mb|14520|331|
# +---+-----+---+
I need to count the rows based on a condition:
from pyspark.sql.functions import col, count
test.groupBy("x").agg(count(col("y") > 12453), count(col("z") > 230)).show()
which gives
+---+------------------+----------------+
| x|count((y > 12453))|count((z > 230))|
+---+------------------+----------------+
| bn| 2| 2|
| mb| 2| 2|
+---+------------------+----------------+
It just returns the total row count for each group, not the count of rows matching each condition.
Upvotes: 39
Views: 103100
Reputation: 24478
Spark 3.5+ has count_if in the Python API:
from pyspark.sql import functions as F
test.groupBy('x').agg(
    F.count_if(F.col('y') > 12453).alias('y_cnt'),
    F.count_if(F.col('z') > 230).alias('z_cnt')
).show()
# +---+-----+-----+
# | x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn| 0| 0|
# | mb| 2| 2|
# +---+-----+-----+
Spark 3.0+ has it too, but expr must be used:
test.groupBy('x').agg(
    F.expr("count_if(y > 12453) y_cnt"),
    F.expr("count_if(z > 230) z_cnt")
).show()
# +---+-----+-----+
# | x|y_cnt|z_cnt|
# +---+-----+-----+
# | bn| 0| 0|
# | mb| 2| 2|
# +---+-----+-----+
Upvotes: 3
Reputation: 39
The count function skips null values, so you can try this:
import pyspark.sql.functions as F

def count_with_condition(cond):
    # when(cond, True) yields null when cond is false, and count ignores nulls
    return F.count(F.when(cond, True))
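A quick usage sketch with this helper on the question's dataframe (the y_cnt / z_cnt aliases are just illustrative):
test.groupBy('x').agg(
    count_with_condition(F.col('y') > 12453).alias('y_cnt'),
    count_with_condition(F.col('z') > 230).alias('z_cnt')
).show()
# same result as the other answers: bn -> 0 / 0, mb -> 2 / 2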
There is also a function for this in this repo: kolang
Upvotes: 3
Reputation: 1828
Since Spark 3.0.0 there is count_if(expr), see the Spark function documentation.
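A minimal sketch of how it can be called from PySpark at that version (before the Python wrapper arrived in 3.5), assuming the question's dataframe is registered as a temp view named test:
test.createOrReplaceTempView("test")
spark.sql("""
    SELECT x,
           count_if(y > 12453) AS y_cnt,
           count_if(z > 230) AS z_cnt
    FROM test
    GROUP BY x
""").show()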
Upvotes: 2
Reputation: 2457
Based on @Psidom's answer, my answer is as follows:
from pyspark.sql.functions import col,when,count
test.groupBy("x").agg(
count(when(col("y") > 12453, True)),
count(when(col("z") > 230, True))
).show()
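The auto-generated column names from count(when(...)) are fairly verbose; if you want readable names, the same aggregates can be aliased (y_cnt / z_cnt are just illustrative):
test.groupBy("x").agg(
    count(when(col("y") > 12453, True)).alias("y_cnt"),
    count(when(col("z") > 230, True)).alias("z_cnt")
).show()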
Upvotes: 37
Reputation: 215117
count doesn't sum Trues, it only counts the number of non-null values. To count the True values, you need to convert the conditions to 1 / 0 and then sum:
import pyspark.sql.functions as F
cnt_cond = lambda cond: F.sum(F.when(cond, 1).otherwise(0))
test.groupBy('x').agg(
    cnt_cond(F.col('y') > 12453).alias('y_cnt'),
    cnt_cond(F.col('z') > 230).alias('z_cnt')
).show()
+---+-----+-----+
| x|y_cnt|z_cnt|
+---+-----+-----+
| bn| 0| 0|
| mb| 2| 2|
+---+-----+-----+
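Worth noting on the .otherwise(0): sum skips nulls, so a variant without it gives the same counts for groups that have at least one match, but a group with no matching rows comes back as null instead of 0 (cnt_cond_nullable below is just an illustrative name):
cnt_cond_nullable = lambda cond: F.sum(F.when(cond, 1))  # null, not 0, when no row in the group matches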
Upvotes: 76