Thomas

Reputation: 5104

GroupBy column and filter rows with maximum value in Pyspark

I am almost certain this has been asked before, but a search through stackoverflow did not answer my question. Not a duplicate of [2] since I want the maximum value, not the most frequent item. I am new to pyspark and trying to do something really simple: I want to groupBy column "A" and then only keep the row of each group that has the maximum value in column "B". Like this:

df_cleaned = df.groupBy("A").agg(F.max("B"))

Unfortunately, this throws away all other columns: df_cleaned only contains column "A" and the max value of B. How do I instead keep the full rows ("A", "B", "C", ...)?
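For concreteness, a minimal sketch of the problem (the data and the extra column "C" are made up for illustration):

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [('a', 5, 'x'), ('a', 8, 'y'), ('b', 3, 'z')],
    ["A", "B", "C"],
)

# only "A" and "max(B)" survive; column "C" is lost
df.groupBy("A").agg(F.max("B")).show()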

Upvotes: 67

Views: 148864

Answers (5)

Disenchanted

Reputation: 613

Another solution is to number the rows via row_number() over a window partitioned by A and ordered by B descending. This solution is close to the one by @pault, but when there are several rows with the maximum value it only keeps one of them, which I find better.

Given the same example:

data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
    ('b', 3)
]
df = spark.createDataFrame(data, ["A", "B"])
df.show()

The row_number solution is:

import pyspark.sql.functions as F
from pyspark.sql import Window

w = Window.partitionBy('A').orderBy(F.col('B').desc())
df.withColumn('row_number', F.row_number().over(w)) \
    .filter(F.col('row_number') == 1) \
    .drop('row_number') \
    .show()

+---+---+
|  A|  B|
+---+---+
|  a|  8|
|  b|  3|
+---+---+
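Not part of the original answer, but here is a small sketch (with made-up rows that share a tied maximum, and reusing the `spark` session from above) that makes the tie handling visible: row_number keeps one row per group, while the max-over-window filter keeps every tied row.

import pyspark.sql.functions as F
from pyspark.sql import Window

# two rows in group 'a' share the maximum B = 8
ties = spark.createDataFrame(
    [('a', 8, 'x'), ('a', 8, 'y'), ('b', 3, 'z')],
    ["A", "B", "C"],
)

# row_number keeps exactly one row per group, even on ties
w_rn = Window.partitionBy('A').orderBy(F.col('B').desc())
ties.withColumn('rn', F.row_number().over(w_rn)).filter(F.col('rn') == 1).drop('rn').show()

# the max-over-window filter keeps every row that ties for the maximum
w_max = Window.partitionBy('A')
ties.withColumn('maxB', F.max('B').over(w_max)).where(F.col('B') == F.col('maxB')).drop('maxB').show()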

I also extended the benchmark from @Fernando Wittmann; both solutions run in about the same time:

The dataframe:

import random

N_SAMPLES = 600000
N_PARTITIONS = 1000
MAX_VALUE = 100
data = zip(
    [random.randint(0, N_PARTITIONS-1) for i in range(N_SAMPLES)],
    [random.randint(0, MAX_VALUE) for i in range(N_SAMPLES)],
    list(range(N_SAMPLES))
)
df = spark.createDataFrame(data, ["A", "B", "C"])

row_number approach:

%%timeit
w = Window.partitionBy('A').orderBy(F.col('B').desc())
df_collect = df.withColumn('row_number', F.row_number().over(w)) \
    .filter(F.col('row_number') == 1) \
    .drop('row_number') \
    .collect()
313 ms ± 19.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

max approach:

%%timeit
w = Window.partitionBy('A')
df_collect = df.withColumn('maxB', F.max('B').over(w))\
    .where(F.col('B') == F.col('maxB'))\
    .drop('maxB')\
    .collect()
328 ms ± 24.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

leftsemi approach:

%%timeit
df_collect = df.join(df.groupBy('A').agg(F.max('B').alias('B')),on='B',how='leftsemi').collect()
516 ms ± 19.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Upvotes: 1

Fernando Wittmann

Reputation: 2537

There are two great solutions, so I decided to benchmark them. First let me define a bigger dataframe:

import random

N_SAMPLES = 600000
N_PARTITIONS = 1000
MAX_VALUE = 100
data = zip(
    [random.randint(0, N_PARTITIONS-1) for i in range(N_SAMPLES)],
    [random.randint(0, MAX_VALUE) for i in range(N_SAMPLES)],
    list(range(N_SAMPLES))
)
df = spark.createDataFrame(data, ["A", "B", "C"])
df.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|118| 91|  0|
|439| 80|  1|
|779| 77|  2|
|444| 14|  3|
...

Benchmarking @pault's solution:

%%timeit
w = Window.partitionBy('A')
df_collect = df.withColumn('maxB', f.max('B').over(w))\
    .where(f.col('B') == f.col('maxB'))\
    .drop('maxB')\
    .collect()

gives

655 ms ± 70.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Benchmarking @ndricca's solution:

%%timeit
df_collect = df.join(df.groupBy('A').agg(f.max('B').alias('B')),on='B',how='leftsemi').collect()

gives

1 s ± 49.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So, @pault's solution seems to be 1.5x faster. Feedback on this benchmark is very welcome.

Upvotes: 8

user9875189

Reputation: 309

Just want to add the Scala Spark version of @ndricca's answer in case anyone needs it:

val data = Seq(("a", 5,"c"), ("a",8,"d"),("a",7,"e"),("b",1,"f"),("b",3,"g"))
val df = data.toDF("A","B","C")
df.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  5|  c|
|  a|  8|  d|
|  a|  7|  e|
|  b|  1|  f|
|  b|  3|  g|
+---+---+---+

val rightdf = df.groupBy("A").max("B")
rightdf.show()
+---+------+
|  A|max(B)|
+---+------+
|  b|     3|
|  a|     8|
+---+------+

val resdf = df.join(rightdf, df("A") === rightdf("A") && df("B") === rightdf("max(B)"), "leftsemi")
resdf.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  8|  d|
|  b|  3|  g|
+---+---+---+

Upvotes: 3

ndricca

Reputation: 532

Another possible approach is to join the dataframe with itself, specifying "leftsemi". This kind of join keeps all columns from the dataframe on the left side and no columns from the right side.

For example:

import pyspark.sql.functions as f
data = [
    ('a', 5, 'c'),
    ('a', 8, 'd'),
    ('a', 7, 'e'),
    ('b', 1, 'f'),
    ('b', 3, 'g')
]
df = sqlContext.createDataFrame(data, ["A", "B", "C"])
df.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  5|  c|
|  a|  8|  d|
|  a|  7|  e|
|  b|  1|  f|
|  b|  3|  g|
+---+---+---+

The max value of column B for each value of column A can be selected doing:

df.groupBy('A').agg(f.max('B')).show()
+---+------+
|  A|max(B)|
+---+------+
|  a|     8|
|  b|     3|
+---+------+

Using this expression as the right side of a left semi join, joining on both A and B, and renaming the obtained column max(B) back to its original name B, we can obtain the result needed:

df.join(df.groupBy('A').agg(f.max('B').alias('B')), on=['A', 'B'], how='leftsemi').show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  b|  3|  g|
|  a|  8|  d|
+---+---+---+

The physical plan behind this solution and the one from the accepted answer are different, and it is still not clear to me which one will perform better on large dataframes.
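Not part of the original answer, but if you want to compare the plans yourself, a minimal sketch (reusing the `df` and `f` defined above) is to call explain() on both variants:

from pyspark.sql import Window

# physical plan of the leftsemi-join version
df.join(df.groupBy('A').agg(f.max('B').alias('B')), on=['A', 'B'], how='leftsemi').explain()

# physical plan of the window version from the accepted answer
w = Window.partitionBy('A')
df.withColumn('maxB', f.max('B').over(w)).where(f.col('B') == f.col('maxB')).drop('maxB').explain()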

The same result can be obtained using Spark SQL syntax:

df.registerTempTable('table')
q = '''SELECT *
FROM table a
LEFT SEMI JOIN (
    SELECT
        A,
        max(B) as max_B
    FROM table
    GROUP BY A
) t
ON a.A = t.A AND a.B = t.max_B
'''
sqlContext.sql(q).show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  b|  3|  g|
|  a|  8|  d|
+---+---+---+

Upvotes: 21

pault

Reputation: 43504

You can do this without a udf using a Window.

Consider the following example:

import pyspark.sql.functions as f
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
    ('b', 3)
]
df = sqlCtx.createDataFrame(data, ["A", "B"])
df.show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  5|
#|  a|  8|
#|  a|  7|
#|  b|  1|
#|  b|  3|
#+---+---+

Create a Window to partition by column A and use it to compute the maximum of each group. Then keep only the rows where the value in column B equals that maximum.

from pyspark.sql import Window
w = Window.partitionBy('A')
df.withColumn('maxB', f.max('B').over(w))\
    .where(f.col('B') == f.col('maxB'))\
    .drop('maxB')\
    .show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  8|
#|  b|  3|
#+---+---+

Or equivalently using pyspark-sql:

df.registerTempTable('table')
q = "SELECT A, B FROM (SELECT *, MAX(B) OVER (PARTITION BY A) AS maxB FROM table) M WHERE B = maxB"
sqlCtx.sql(q).show()
#+---+---+
#|  A|  B|
#+---+---+
#|  b|  3|
#|  a|  8|
#+---+---+

Upvotes: 81
