max

Reputation: 52373

argmax in Spark DataFrames: how to retrieve the row with the maximum value

Given a Spark DataFrame df, I want to find the maximum value in a certain numeric column 'values', and obtain the row(s) where that value was reached. I can of course do this:

# Scala or Python, it doesn't matter which:
# I hope to get this done with the DataFrame API
import pyspark.sql.functions as F
max_value = df.select(F.max('values')).collect()[0][0]
df.filter(df.values == max_value).show()

but this is inefficient since it requires two passes through df.

pandas.Series/DataFrame and numpy.array have argmax/idxmax methods that do this efficiently (in one pass). So does standard Python: the built-in function max accepts a key parameter, so it can be used to find the index of the highest value.
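
For reference, the single-pass idiom in plain Python looks roughly like this. The list `values` is made-up sample data, not something from my actual DataFrame.

```python
# Hypothetical sample data, for illustration only.
values = [3, 41, 7, 41, 12]

# max() walks the indices once; key= tells it to compare by the value at each index.
argmax_index = max(range(len(values)), key=values.__getitem__)

# With ties, max() returns the first maximal element it encounters.
max_value = values[argmax_index]
```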

What is the right approach in Spark? Note that I don't mind whether I get all the rows where the maximum value is reached, or just some arbitrary (non-empty!) subset of those rows.

Upvotes: 18

Views: 22568

Answers (2)

zero323

Reputation: 330433

If the schema is orderable (it contains only atomic types, arrays of atomics, or recursively orderable structs), you can use a simple aggregation:

Python:

df.select(F.max(
    F.struct("values", *(x for x in df.columns if x != "values"))
)).first()

Scala:

df.select(max(struct(
    $"values" +: df.columns.collect {case x if x!= "values" => col(x)}: _*
))).first
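
The struct trick works because Spark compares structs field by field, from left to right, so putting "values" first makes it the primary key and the remaining columns only break ties. Python tuples compare the same way, which allows a Spark-free sketch of the idea. The rows below are invented sample data.

```python
# Hypothetical rows as (name, id, values) tuples, for illustration only.
rows = [("a", 1, 100), ("b", 2, 120), ("c", 10, 1000), ("d", 14, 1000)]

# Move "values" to the front, mirroring struct("values", <other columns>).
keyed = [(values, name, id_) for (name, id_, values) in rows]

# Tuples compare lexicographically, so max() picks the largest "values";
# ties are broken by the remaining fields, in order.
best = max(keyed)
```

Note that the result has "values" as its first field, so the column order differs from the original schema.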

Otherwise you can reduce over the Dataset (Scala only), but this requires additional deserialization:

type T = ???

df.reduce((a, b) => if (a.getAs[T]("values") > b.getAs[T]("values")) a else b)
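
The same pairwise reduction can be sketched in plain Python with functools.reduce, which may make the logic easier to follow. This is only an analogy for the Scala snippet, using invented rows stored as dicts.

```python
from functools import reduce

# Hypothetical rows, for illustration only.
rows = [
    {"name": "a", "values": 100},
    {"name": "b", "values": 120},
    {"name": "c", "values": 1000},
    {"name": "d", "values": 1000},
]

def keep_larger(a, b):
    # Keep a only if it is strictly larger; on a tie b wins,
    # matching `if (a > b) a else b`.
    return a if a["values"] > b["values"] else b

# One pass over the data, carrying the best row seen so far.
best_row = reduce(keep_larger, rows)
```

In Spark the order in which partitions are combined is not fixed, so which of several tied rows comes back should be treated as arbitrary.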

You can also use orderBy with limit(1) / take(1):

Scala:

df.orderBy(desc("values")).limit(1)
// or
df.orderBy(desc("values")).take(1)

Python:

df.orderBy(F.desc('values')).limit(1)
# or
df.orderBy(F.desc("values")).take(1)

Upvotes: 17

Alberto Bonsanto

Reputation: 18042

This may be an incomplete answer, but you can use the DataFrame's underlying RDD and apply its max method with a key function to get the maximum record.

a = sc.parallelize([
    ("a", 1, 100),
    ("b", 2, 120),
    ("c", 10, 1000),
    ("d", 14, 1000)
  ]).toDF(["name", "id", "salary"])

a.rdd.max(key=lambda x: x["salary"]) # Row(name=u'c', id=10, salary=1000)

Upvotes: 4
