user1896943

Weighted mean median quartiles in Spark

I have a Spark SQL dataframe:

id Value Weights
1 2 4
1 5 2
2 1 4
2 6 2
2 9 4
3 2 4

I need to group by 'id' and aggregate to get the weighted mean, median, and quartiles of the values for each 'id'. What is the best way to do this?


Answers (1)

ZygD

Before calculating the statistics, apply a small transformation to the Value column:

F.explode(F.array_repeat('Value', F.col('Weights').cast('int')))
  • array_repeat creates an array from the value, repeating it as many times as specified in the 'Weights' column. The cast to int is necessary because array_repeat expects the count column to be of int type. After this step, the first Value of 2 (with weight 4) becomes [2,2,2,2].

  • Then explode creates one row for every element of the array, so [2,2,2,2] becomes 4 rows, each containing the integer 2.

  • You can then calculate the statistics as usual. The results are weighted because each value now appears as many times as its weight. This only works for integer weights: the cast to int truncates any fractional part.
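To see why the exploded column gives weighted results, here is a minimal plain-Python sketch of the same expansion (no Spark involved; the helpers `expand` and `weighted_mean` are hypothetical names made up for illustration). It repeats each value weight times per id and checks that the ordinary mean of the expanded list equals the weighted mean sum(Value * Weights) / sum(Weights).

```python
# Plain-Python model of array_repeat + explode (illustration only, not Spark).
# Each (id, value, weight) row becomes `weight` copies of `value` for that id.
from collections import defaultdict

rows = [(1, 2, 4), (1, 5, 2), (2, 1, 4), (2, 6, 2), (2, 9, 4), (3, 2, 4)]


def expand(rows):
    """Return {id: [value, value, ...]} with each value repeated weight times."""
    groups = defaultdict(list)
    for id_, value, weight in rows:
        # int() truncates toward zero, like the .cast('int') in the Spark code
        groups[id_].extend([value] * int(weight))
    return dict(groups)


def weighted_mean(rows, key):
    """Weighted mean sum(v * w) / sum(w) over the rows belonging to `key`."""
    num = sum(v * w for i, v, w in rows if i == key)
    den = sum(w for i, v, w in rows if i == key)
    return num / den


expanded = expand(rows)
for key, values in expanded.items():
    # plain mean of the expanded list vs. weighted mean of the original rows
    print(key, values, sum(values) / len(values), weighted_mean(rows, key))
```

For id 1 both means are 3.0: (2*4 + 5*2) / 6 = 18 / 6.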

Full example:

from pyspark.sql import SparkSession, functions as F
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, 2, 4),
     (1, 5, 2),
     (2, 1, 4),
     (2, 6, 2),
     (2, 9, 4),
     (3, 2, 4)],
    ['id', 'Value', 'Weights']
)

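# Repeat each Value 'Weights' times, then explode the array into one row per copy.
# The exploded column gets Spark's default name, 'col'.
df = df.select('id', F.explode(F.array_repeat('Value', F.col('Weights').cast('int'))))
# percentile() computes exact percentiles, interpolating between neighbouring values.
df = (df
    .groupBy('id')
    .agg(F.mean('col').alias('weighted_mean'),
         F.expr('percentile(col, 0.5)').alias('weighted_median'),
         F.expr('percentile(col, 0.25)').alias('weighted_lower_quartile'),
         F.expr('percentile(col, 0.75)').alias('weighted_upper_quartile')))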
df.show()
#+---+-------------+---------------+-----------------------+-----------------------+
#| id|weighted_mean|weighted_median|weighted_lower_quartile|weighted_upper_quartile|
#+---+-------------+---------------+-----------------------+-----------------------+
#|  1|          3.0|            2.0|                    2.0|                   4.25|
#|  2|          5.2|            6.0|                    1.0|                    9.0|
#|  3|          2.0|            2.0|                    2.0|                    2.0|
#+---+-------------+---------------+-----------------------+-----------------------+
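The 4.25 upper quartile for id 1 comes from linear interpolation: the exact percentile is taken at the fractional position (n - 1) * p in the sorted column and interpolated between the two nearest values. The plain-Python sketch below (the `percentile` helper is hypothetical, written only to mirror that rule; it is an assumption that it matches Spark's implementation in every edge case) reproduces the values in the table.

```python
# Exact percentile with linear interpolation between the two nearest ranks.
# Hypothetical helper mirroring the rule the Spark output above is consistent with.
def percentile(values, p):
    """Return the p-th percentile (0 <= p <= 1) of a non-empty list."""
    s = sorted(values)
    pos = (len(s) - 1) * p        # fractional index into the sorted list
    lo = int(pos)                 # floor, since pos >= 0
    hi = min(lo + 1, len(s) - 1)  # next rank, clamped at the last element
    frac = pos - lo
    return s[lo] + (s[hi] - s[lo]) * frac


# id 1 after expansion: 2 repeated 4 times, 5 repeated 2 times
id1 = [2, 2, 2, 2, 5, 5]
print(percentile(id1, 0.25), percentile(id1, 0.5), percentile(id1, 0.75))
# prints: 2.0 2.0 4.25
```

For p = 0.75 the position is 5 * 0.75 = 3.75, so the result is 2 + (5 - 2) * 0.75 = 4.25.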

