Ali

Reputation: 8100

Pyspark divide column by its subtotals grouped by another column

My problem is similar to this and this. Both posts show how to divide a column value by the total sum of the same column. In my case I want to divide the values of a column by the sum of subtotals. Subtotal is calculated by grouping the column values depending on another column. I am slightly modifying the example mentioned in the links shared above.

Here is my dataframe

df = [[1,'CAT1',10], [2, 'CAT1', 11], [3, 'CAT2', 20], [4, 'CAT2', 22], [5, 'CAT3', 30]]
df = spark.createDataFrame(df, ['id', 'category', 'consumption'])
df.show()
+---+--------+-----------+
| id|category|consumption|
+---+--------+-----------+
|  1|    CAT1|         10|
|  2|    CAT1|         11|
|  3|    CAT2|         20|
|  4|    CAT2|         22|
|  5|    CAT3|         30|
+---+--------+-----------+

I want to divide each "consumption" value by the subtotal of its "category" group and put the result in a column "normalized", as below. The subtotals themselves (21, 42 and 30) don't need to appear in the output:

+---+--------+-----------+-------------------+
| id|category|consumption|         normalized|
+---+--------+-----------+-------------------+
|  1|    CAT1|         10|0.47619047619047616|
|  2|    CAT1|         11| 0.5238095238095238|
|  3|    CAT2|         20|0.47619047619047616|
|  4|    CAT2|         22| 0.5238095238095238|
|  5|    CAT3|         30|                1.0|
+---+--------+-----------+-------------------+

What I've achieved so far:

import pyspark.sql.functions as F

df.crossJoin(
    df.groupby('category').agg(F.sum('consumption').alias('sum_'))
).withColumn("normalized", F.col("consumption")/F.col("sum_"))\
 .show()

+---+--------+-----------+--------+----+-------------------+
| id|category|consumption|category|sum_|         normalized|
+---+--------+-----------+--------+----+-------------------+
|  1|    CAT1|         10|    CAT2|  42|0.23809523809523808|
|  2|    CAT1|         11|    CAT2|  42| 0.2619047619047619|
|  1|    CAT1|         10|    CAT1|  21|0.47619047619047616|
|  2|    CAT1|         11|    CAT1|  21| 0.5238095238095238|
|  1|    CAT1|         10|    CAT3|  30| 0.3333333333333333|
|  2|    CAT1|         11|    CAT3|  30|0.36666666666666664|
|  3|    CAT2|         20|    CAT2|  42|0.47619047619047616|
|  4|    CAT2|         22|    CAT2|  42| 0.5238095238095238|
|  5|    CAT3|         30|    CAT2|  42| 0.7142857142857143|
|  3|    CAT2|         20|    CAT1|  21| 0.9523809523809523|
|  4|    CAT2|         22|    CAT1|  21| 1.0476190476190477|
|  5|    CAT3|         30|    CAT1|  21| 1.4285714285714286|
|  3|    CAT2|         20|    CAT3|  30| 0.6666666666666666|
|  4|    CAT2|         22|    CAT3|  30| 0.7333333333333333|
|  5|    CAT3|         30|    CAT3|  30|                1.0|
+---+--------+-----------+--------+----+-------------------+

Upvotes: 1

Views: 3111

Answers (2)

cph_sto

Reputation: 7585

This is another way of solving the problem proposed by the OP, but without using a join.

Joins are in general costly operations and should be avoided whenever possible.

# First register the DataFrame as a temporary SQL view
df.createOrReplaceTempView('table_view')
df = spark.sql("""select id, category,
                  consumption/sum(consumption) over (partition by category) as normalize
                  from table_view""")
df.show()
+---+--------+-------------------+
| id|category|          normalize|
+---+--------+-------------------+
|  3|    CAT2|0.47619047619047616|
|  4|    CAT2| 0.5238095238095238|
|  1|    CAT1|0.47619047619047616|
|  2|    CAT1| 0.5238095238095238|
|  5|    CAT3|                1.0|
+---+--------+-------------------+

Note: triple quotes (""") are used so the SQL statement can span multiple lines for readability; a plain single-quoted 'select id ....' string would not work spread over multiple lines. The final result is the same either way.

Upvotes: 1

cronoik

Reputation: 19320

You can do basically the same as in the links you have already mentioned. The only difference is that you have to calculate the subtotals first with groupby and sum:

import pyspark.sql.functions as F
df = df.join(df.groupby('category').sum('consumption'), 'category')
df = df.select('id', 'category', F.round(F.col('consumption')/F.col('sum(consumption)'), 2).alias('normalized'))
df.show()

Output:

+---+--------+----------+ 
| id|category|normalized| 
+---+--------+----------+ 
|  3|    CAT2|      0.48| 
|  4|    CAT2|      0.52| 
|  1|    CAT1|      0.48| 
|  2|    CAT1|      0.52| 
|  5|    CAT3|       1.0| 
+---+--------+----------+ 

Upvotes: 1
