에스파파

Reputation: 57

Pyspark group by and count data with condition

I would like to solve a problem using group-by functions. Let me show you my case. The data I have looks like this:

| columnA | columnB | columnC | columnD | columnE |
| ------- | ------- | ------- | ------- | ------- |
| PersonA | DataOne | 20210101|    1    |    2    |
| PersonA | DataOne | 20210102|    2    |    4    |
| PersonA | DataOne | 20210102|    3    |    4    |
| PersonA | DataTwo | 20201226|    2    |    4    |
| PersonA | DataTwo | 20201226|    7    |    1    |
| PersonA | DataTwo | 20201227|    3    |    2    |
| PersonB | DataOne | 20201225|    1    |    3    |
| PersonB | DataTwo | 20201225|    2    |    4    |
| PersonB | DataTwo | 20201226|    1    |    2    |

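In case it helps to reproduce the examples below, here is a minimal sketch that builds this DataFrame (assuming an active SparkSession named spark, and with columnC stored as a yyyyMMdd string; an integer column would order the same way):

# Minimal sketch to reproduce the sample data above
my_df = spark.createDataFrame(
    [
        ("PersonA", "DataOne", "20210101", 1, 2),
        ("PersonA", "DataOne", "20210102", 2, 4),
        ("PersonA", "DataOne", "20210102", 3, 4),
        ("PersonA", "DataTwo", "20201226", 2, 4),
        ("PersonA", "DataTwo", "20201226", 7, 1),
        ("PersonA", "DataTwo", "20201227", 3, 2),
        ("PersonB", "DataOne", "20201225", 1, 3),
        ("PersonB", "DataTwo", "20201225", 2, 4),
        ("PersonB", "DataTwo", "20201226", 1, 2),
    ],
    ["columnA", "columnB", "columnC", "columnD", "columnE"],
)
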
Then, what I want to do is aggregate columnD and columnE (count of columnD, sum of columnE) grouped by columnA, columnB, and columnC, but using only the rows with max(columnC) for each columnA/columnB pair.

I did the job with the code below, but I have been wondering whether there is a simpler and faster way.

from pyspark.sql.functions import col, count, max, sum

# my_df is the DataFrame built from the data above

# Max columnC per (columnA, columnB) pair
my_df_max = my_df.groupBy("columnA", "columnB").agg(max("columnC").alias("columnC"))

# Aggregate per (columnA, columnB, columnC), then keep only the max-columnC rows via a join
result = my_df\
    .groupBy("columnA", "columnB", "columnC")\
    .agg(count("columnD").alias("columnD"), sum("columnE").alias("columnE"))\
    .alias("tempA")\
    .join(my_df_max.alias("tempB"),
          (col("tempA.columnA") == col("tempB.columnA"))
          & (col("tempA.columnB") == col("tempB.columnB"))
          & (col("tempA.columnC") == col("tempB.columnC")))\
    .select(col("tempA.columnA"), col("tempA.columnB"), col("tempA.columnC"), col("columnD"), col("columnE"))

And the result I expect is like below.

|columnA|columnB|columnC |columnD|columnE|
|-------|-------|--------|-------|-------|
|PersonA|DataOne|20210102|   2   |   8   |
|PersonA|DataTwo|20201227|   1   |   2   |
|PersonB|DataOne|20201225|   1   |   3   |
|PersonB|DataTwo|20201226|   1   |   2   |

If I could learn both the DataFrame API way and the SQL way to do this job, I would be very pleased.

Upvotes: 1

Views: 1382

Answers (2)

mck

Reputation: 42332

Here is a Spark SQL way to do this. You can filter the rows with the max columnC using rank() over an appropriate window, and then do the group by and aggregation. Using rank() (rather than row_number()) keeps all rows tied at the max columnC, which is what you want because those tied rows are then aggregated together.

df.createOrReplaceTempView('df')

result = spark.sql("""
    SELECT columnA, columnB, columnC, count(columnD) columnD, sum(columnE) columnE 
    FROM (
        SELECT *, rank() over(partition by columnA, columnB order by columnC desc) r 
        FROM df
    )
    WHERE r = 1
    GROUP BY columnA, columnB, columnC
""")

result.show()
+-------+-------+--------+-------+-------+
|columnA|columnB| columnC|columnD|columnE|
+-------+-------+--------+-------+-------+
|PersonB|DataOne|20201225|      1|      3|
|PersonA|DataOne|20210102|      2|      8|
|PersonB|DataTwo|20201226|      1|      2|
|PersonA|DataTwo|20201227|      1|      2|
+-------+-------+--------+-------+-------+

Upvotes: 2

akuiper

Reputation: 214927

One possibly more concise option is to filter your data frame by the maximum value in columnC first and then do the aggregation (assuming your Spark data frame is named sdf):

import pyspark.sql.functions as f

sdf.withColumn('rankC', f.expr('dense_rank() over (partition by columnA, columnB order by columnC desc)'))\
    .filter(f.col('rankC') == 1)\
    .groupBy('columnA', 'columnB', 'columnC')\
    .agg(f.count('columnD').alias('columnD'), f.sum('columnE').alias('columnE'))\
    .show()

+-------+-------+--------+-------+-------+
|columnA|columnB| columnC|columnD|columnE|
+-------+-------+--------+-------+-------+
|PersonB|DataOne|20201225|      1|      3|
|PersonA|DataOne|20210102|      2|      8|
|PersonB|DataTwo|20201226|      1|      2|
|PersonA|DataTwo|20201227|      1|      2|
+-------+-------+--------+-------+-------+
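
For reference, the same window could also be expressed with the Window class instead of a SQL string inside f.expr. A minimal sketch of that variant (equivalent logic, not part of the original code above):

from pyspark.sql import functions as f
from pyspark.sql.window import Window

# Rank rows within each (columnA, columnB) group by columnC, descending
w = Window.partitionBy('columnA', 'columnB').orderBy(f.col('columnC').desc())

sdf.withColumn('rankC', f.dense_rank().over(w))\
    .filter(f.col('rankC') == 1)\
    .groupBy('columnA', 'columnB', 'columnC')\
    .agg(f.count('columnD').alias('columnD'), f.sum('columnE').alias('columnE'))\
    .show()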

Upvotes: 3
