BhishanPoudel

Reputation: 17144

How to find row names where values are maximum in a column in Pyspark

I have a table like this:

+-------+-----+------+------+
|user_id|apple|banana|carrot|
+-------+-----+------+------+
| user_0|    0|     3|     1|
| user_1|    1|     0|     2|
| user_2|    5|     1|     2|
+-------+-----+------+------+

Here, for each fruit, I want to get the list of customers who bought the most items. The required output is:

                max_user max_count
apple           [user_2]         5
banana          [user_0]         3
carrot  [user_1, user_2]         2

MWE

import pandas as pd
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()


# pandas dataframe
pdf = pd.DataFrame({'user_id': ['user_0','user_1','user_2'],
                   'apple': [0,1,5],
                   'banana': [3,0,1],
                   'carrot': [1,2,2]})


# spark dataframe (SparkSession.createDataFrame accepts a pandas DataFrame directly)
df = spark.createDataFrame(pdf)
# df.show()


df.createOrReplaceTempView("grocery")
spark.sql('select * from grocery').show()

Question 1

How can I get the required output with the PySpark DataFrame API?

Question 2

How can I get the required output with PySpark SQL?

References

I have already done some research and searched multiple pages. So far I have found one close answer, but it requires a transposed table, whereas my table is in the usual wide format (one row per user, one column per fruit). I am also trying to learn both approaches: the Spark DataFrame method and the SQL method.

Upvotes: 2

Views: 465

Answers (3)

mck

Reputation: 42332

PySpark solution. As in the pandas solutions, first melt (unpivot) the dataframe using stack, then keep only the rows with the maximum count using rank, group by fruit, and gather the users into a list using collect_list.

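from pyspark.sql import functions as F, Window

# Unpivot every column except user_id
fruit_cols = df.columns[1:]
# Builds "stack(3, 'apple', `apple`, 'banana', `banana`, ...) as (fruit, items)";
# backticks keep column names with spaces valid
stack_expr = 'stack({}, {}) as (fruit, items)'.format(
    len(fruit_cols),
    ', '.join("'{0}', `{0}`".format(c) for c in fruit_cols)
)

df2 = df.selectExpr('user_id', stack_expr).withColumn(
    'rn',
    # rank() gives tied maxima the same rank 1, so ties are all kept
    F.rank().over(Window.partitionBy('fruit').orderBy(F.desc('items')))
).filter('rn = 1').groupBy('fruit').agg(
    F.collect_list('user_id').alias('max_user'),
    F.max('items').alias('max_count')
)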

df2.show()
+------+----------------+---------+
| fruit|        max_user|max_count|
+------+----------------+---------+
| apple|        [user_2]|        5|
|banana|        [user_0]|        3|
|carrot|[user_1, user_2]|        2|
+------+----------------+---------+
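To see what the three steps (stack, rank, collect_list) do row by row without a Spark session, here is a minimal pure-Python sketch of the same logic on the question's sample values. The function name `max_users_per_fruit` is hypothetical, and the rank is emulated as 1 plus the number of strictly larger counts in the same fruit, which is how SQL `rank()` treats ties (unlike `row_number()`).

```python
# Hypothetical pure-Python sketch of the stack -> rank -> collect_list pipeline.
# Sample values copied from the question; no Spark is involved.
rows = [
    {"user_id": "user_0", "apple": 0, "banana": 3, "carrot": 1},
    {"user_id": "user_1", "apple": 1, "banana": 0, "carrot": 2},
    {"user_id": "user_2", "apple": 5, "banana": 1, "carrot": 2},
]


def max_users_per_fruit(rows, id_col="user_id"):
    # Step 1 (like stack/melt): wide rows -> (user, fruit, items) triples
    long_rows = [
        (row[id_col], fruit, items)
        for row in rows
        for fruit, items in row.items()
        if fruit != id_col
    ]
    # Group by fruit (the window's "partition by fruit")
    by_fruit = {}
    for user, fruit, items in long_rows:
        by_fruit.setdefault(fruit, []).append((user, items))

    result = {}
    for fruit, pairs in by_fruit.items():
        # Step 2 (like rank() over ... order by items desc):
        # rank = 1 + number of strictly larger counts, so ties share rank 1
        ranked = [
            (user, items, 1 + sum(other > items for _, other in pairs))
            for user, items in pairs
        ]
        # Step 3 (like filter rn = 1, then collect_list and max)
        top = [(user, items) for user, items, rn in ranked if rn == 1]
        result[fruit] = ([user for user, _ in top], max(items for _, items in top))
    return result


for fruit, (users, count) in max_users_per_fruit(rows).items():
    print(fruit, users, count)
```

The sketch lists tied users in input order; in Spark, the element order produced by `collect_list` after a shuffle is not guaranteed, so `[user_1, user_2]` may also come back as `[user_2, user_1]`.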

For Spark SQL:

df.createOrReplaceTempView("grocery")

df2 = spark.sql("""
    select
        fruit,
        collect_list(user_id) as max_user,
        max(items) as max_count
    from (
        select *,
            rank() over (partition by fruit order by items desc) as rn
        from (
            select
                user_id,
                stack(3, 'apple', apple, 'banana', banana, 'carrot', carrot) as (fruit, items)
            from grocery
        )
    )
    where rn = 1 group by fruit
""")

df2.show()
+------+----------------+---------+
| fruit|        max_user|max_count|
+------+----------------+---------+
| apple|        [user_2]|        5|
|banana|        [user_0]|        3|
|carrot|[user_1, user_2]|        2|
+------+----------------+---------+
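The SQL version hard-codes the three fruit columns in `stack(3, ...)`. A minimal sketch that generates the same query string from any column list is shown below; `build_max_user_query` is a hypothetical helper, and it assumes column names contain no single quotes or backticks (names with spaces are fine because they are wrapped in backticks).

```python
def build_max_user_query(table, id_col, value_cols):
    """Hypothetical helper: build the answer's Spark SQL query for any value columns."""
    # "'apple', `apple`, 'banana', `banana`, ..." : label literal, then column reference
    pairs = ", ".join("'{0}', `{0}`".format(c) for c in value_cols)
    stack = "stack({}, {}) as (fruit, items)".format(len(value_cols), pairs)
    return """
    select fruit, collect_list({id}) as max_user, max(items) as max_count
    from (
        select *, rank() over (partition by fruit order by items desc) as rn
        from (select {id}, {stack} from {table})
    )
    where rn = 1
    group by fruit
    """.format(id=id_col, stack=stack, table=table)


print(build_max_user_query("grocery", "user_id", ["apple", "banana", "carrot"]))
```

With the session from the question, it would be used as `spark.sql(build_max_user_query("grocery", "user_id", df.columns[1:])).show()`.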

Upvotes: 3

BhishanPoudel

Reputation: 17144

For pandas you can do this:

import pandas as pd

pdf = pd.DataFrame({'user_id': ['user_0','user_1','user_2'],
                   'apple': [0,1,5],
                   'banana': [3,0,1],
                   'carrot': [1,2,2]})


ans = pdf.set_index('user_id').apply(lambda s: pd.Series(
    [(s[s==s.max()]).index.tolist(), s.max()],
    index=['max_user','max_count']
    )).T

ans

This gives:

                max_user max_count
apple           [user_2]         5
banana          [user_0]         3
carrot  [user_1, user_2]         2

Upvotes: 1

Quang Hoang

Reputation: 150735

You can try melt, filter the max values, then groupby().agg():

s = pdf.melt('user_id')  # melt the pandas frame, not the Spark df
max_val = s.groupby('variable')['value'].transform('max')

(s[s['value']==max_val].groupby(['variable'])
     .agg(max_user=('user_id',list),
          max_count=('value', 'first'))
)

Output:

                  max_user  max_count
variable                             
apple             [user_2]          5
banana            [user_0]          3
carrot    [user_1, user_2]          2

Upvotes: 1
