Reputation: 52373
Given a Spark DataFrame df, I want to find the maximum value in a certain numeric column 'values', and obtain the row(s) where that value was reached. I can of course do this:
# it doesn't matter if I use scala or python,
# since I hope I get this done with DataFrame API
import pyspark.sql.functions as F
max_value = df.select(F.max('values')).collect()[0][0]
df.filter(df.values == max_value).show()
but this is inefficient since it requires two passes through df.
pandas.Series/DataFrame and numpy.array have argmax/idxmax methods that do this efficiently (in one pass). So does standard Python (the built-in function max accepts a key parameter, so it can be used to find the index of the highest value).
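For reference, the single-pass behaviour of the built-in max with a key can be sketched as follows. The list `values` and its contents are made-up sample data, not part of the problem above.

```python
# Hypothetical sample data, used only for illustration.
values = [3, 7, 42, 5]

# max() iterates over the indices once; key= maps each index to its value,
# so the result is the index of the largest element (an "argmax").
best_index = max(range(len(values)), key=values.__getitem__)
best_value = values[best_index]
```

If several elements share the maximum, max returns the first one it encounters.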
What is the right approach in Spark? Note that I don't mind whether I get all the rows where the maximum value is achieved, or just some arbitrary (non-empty!) subset of those rows.
Upvotes: 18
Views: 22568
Reputation: 330433
If schema is Orderable (schema contains only atomics / arrays of atomics / recursively orderable structs) you can use simple aggregations:
Python:
df.select(F.max(
F.struct("values", *(x for x in df.columns if x != "values"))
)).first()
Scala:
df.select(max(struct(
$"values" +: df.columns.collect {case x if x!= "values" => col(x)}: _*
))).first
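The reason "values" is placed first in the struct is that, as far as I understand, structs are compared field by field in order, so the first field dominates the comparison. A rough plain-Python analogy using tuples (not Spark code; the `rows` data and the names `keyed`, `best_values`, `best_name` are hypothetical) might look like this:

```python
# Hypothetical rows as (name, values) pairs, used only for illustration.
rows = [("a", 1), ("b", 10), ("c", 7)]

# Move the "values" field to the front, mirroring struct("values", ...rest).
keyed = [(v, name) for name, v in rows]

# Tuples compare field by field, so max() is decided by "values" first
# and only falls back to the remaining fields to break ties.
best_values, best_name = max(keyed)
```

As a consequence, ties on "values" are broken by the remaining columns, so exactly one row comes back.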
Otherwise you can reduce over Dataset (Scala only) but it requires additional deserialization:
type T = ???
df.reduce((a, b) => if (a.getAs[T]("values") > b.getAs[T]("values")) a else b)
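The pairwise logic of that reduce can be illustrated outside Spark with functools.reduce. This is only a local sketch of the idea, assuming rows are plain dicts; `rows` and `keep_larger` are hypothetical names, and nothing here touches the Dataset API.

```python
from functools import reduce

# Hypothetical rows represented as dicts, used only for illustration.
rows = [
    {"name": "a", "values": 1},
    {"name": "b", "values": 10},
    {"name": "c", "values": 7},
]

def keep_larger(a, b):
    # Same rule as the Scala lambda: keep a only if it is strictly larger.
    return a if a["values"] > b["values"] else b

# A single left-to-right pass that carries the current best row along.
best_row = reduce(keep_larger, rows)
```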
You can also use orderBy with limit(1) / take(1):
Scala:
df.orderBy(desc("values")).limit(1)
// or
df.orderBy(desc("values")).take(1)
Python:
df.orderBy(F.desc('values')).limit(1)
# or
df.orderBy(F.desc("values")).take(1)
Upvotes: 17
Reputation: 18042
This may be an incomplete answer, but you can use the DataFrame's underlying RDD, apply its max method, and get the maximum record using a given key.
a = sc.parallelize([
("a", 1, 100),
("b", 2, 120),
("c", 10, 1000),
("d", 14, 1000)
]).toDF(["name", "id", "salary"])
a.rdd.max(key=lambda x: x["salary"]) # Row(name=u'c', id=10, salary=1000)
Upvotes: 4