Juan David

Reputation: 381

How to calculate the max value across some columns per row in PySpark

I have a data frame read with the sqlContext.sql function in PySpark. It contains 4 numeric columns with information per client (ClientId is the key). I need to calculate the max value per client and join this value to the data frame:

+--------+-------+-------+-------+-------+
|ClientId|m_ant21|m_ant22|m_ant23|m_ant24|
+--------+-------+-------+-------+-------+
|       0|   null|   null|   null|   null|
|       1|   null|   null|   null|   null|
|       2|   null|   null|   null|   null|
|       3|   null|   null|   null|   null|
|       4|   null|   null|   null|   null|
|       5|   null|   null|   null|   null|
|       6|     23|     13|     17|      8|
|       7|   null|   null|   null|   null|
|       8|   null|   null|   null|   null|
|       9|   null|   null|   null|   null|
|      10|     34|      2|      4|      0|
|      11|      0|      0|      0|      0|
|      12|      0|      0|      0|      0|
|      13|      0|      0|     30|      0|
|      14|   null|   null|   null|   null|
|      15|   null|   null|   null|   null|
|      16|     37|     29|     29|     29|
|      17|      0|      0|     16|      0|
|      18|      0|      0|      0|      0|
|      19|   null|   null|   null|   null|
+--------+-------+-------+-------+-------+

In this case, the max value for client 6 is 23 and for client 13 it is 30. Rows where all values are null should naturally stay null in the new column.

Please show me how I can do this operation.

Upvotes: 8

Views: 9078

Answers (2)

SheepPerplexed

Reputation: 1152

There is a function for that: pyspark.sql.functions.greatest.

>>> from pyspark.sql.functions import greatest
>>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
>>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect()
[Row(greatest=4)]

The example was taken directly from the docs.

(pyspark.sql.functions.least does the opposite.)
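Applied to the data frame in the question, it might look like the sketch below (assuming the frame is bound to df; the column name max_value is made up here). greatest skips nulls and returns null only when every input is null, which matches the requirement that all-null rows stay null:

import pyspark.sql.functions as F

# Row-wise maximum of the four monthly columns; null only when all four are null
df_with_max = df.withColumn(
    "max_value",
    F.greatest("m_ant21", "m_ant22", "m_ant23", "m_ant24")
)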

Upvotes: 14

Konrad Kostrzewa

Reputation: 835

I think combining the values into an array and then finding the max of it is the simplest approach.

from pyspark.sql.types import *

schema = StructType([
    StructField("ClientId", IntegerType(), True),
    StructField("m_ant21", IntegerType(), True),
    StructField("m_ant22", IntegerType(), True),
    StructField("m_ant23", IntegerType(), True),
    StructField("m_ant24", IntegerType(), True)
])

df = spark\
    .createDataFrame(
        data=[(0, None, None, None, None),
              (1, 23, 13, 17, 99),
              (2, 0, 0, 0, 1),
              (3, 0, None, 1, 0)],
        schema=schema)

import pyspark.sql.functions as F

# UDF that packs the four monthly columns into a single array column
def agg_to_list(m21, m22, m23, m24):
    return [m21, m22, m23, m24]

u_agg_to_list = F.udf(agg_to_list, ArrayType(IntegerType()))

# Sort each array in descending order; nulls go to the end, so the first
# element is the row maximum (or null when every value is null)
df2 = df.withColumn('all_values', u_agg_to_list('m_ant21', 'm_ant22', 'm_ant23', 'm_ant24'))\
        .withColumn('max', F.sort_array("all_values", False)[0])\
        .select('ClientId', 'max')

df2.show()

Output:

+--------+----+
|ClientId|max |
+--------+----+
|0       |null|
|1       |99  |
|2       |1   |
|3       |1   |
+--------+----+
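For what it's worth, on Spark 2.4+ the same result should be reachable without a Python UDF, using the built-in array and array_max functions (array_max skips null elements and returns null for an all-null array). A minimal sketch under that assumption:

import pyspark.sql.functions as F

# Built-in replacement for the UDF: pack the columns, take the array max
df2 = df.withColumn('max', F.array_max(F.array('m_ant21', 'm_ant22', 'm_ant23', 'm_ant24')))\
        .select('ClientId', 'max')

Avoiding the Python UDF keeps the computation inside the JVM, which is generally faster.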

Upvotes: 4
