DanG
DanG

Reputation: 741

Add total per group as a new row in dataframe in Pyspark

Referring to my previous question (here), I am trying to compute and add a total row for each brand, parent and week_num (total of usage).

Here is a dummy sample:

# Dummy sample: each row is (week_num, parent, brand, channel, usage).
columns = ["week_num", "parent", "brand", "channel", "usage"]
rows = [
    (2, "A", "A2", "A2web", 2500),
    (2, "A", "A2", "A2TV", 3500),
    (4, "A", "A1", "A2app", 5500),
    (4, "A", "AD", "ADapp", 2000),
    (4, "B", "B25", "B25app", 7600),
    (4, "B", "B26", "B26app", 5600),
    (5, "C", "c25", "c25app", 2658),
    (5, "C", "c27", "c27app", 1100),
    (5, "C", "c28", "c26app", 1200),
]
df0 = spark.createDataFrame(rows, schema=columns)

This snippet adds a total row at the channel level:

# Sum usage per (week_num, parent, brand) and label the resulting rows
# "Total" in the channel column so they can be unioned with the detail rows.
totals = (
    df0.groupBy(["week_num", "parent", "brand"])
    .agg(f.sum("usage").alias("usage"))
    .withColumn("channel", f.lit("Total"))
)

# Temporary sort key: detail rows (1) sort before their group's total row (2).
# It is attached on the fly so df0 stays unchanged, and it is dropped after
# sorting so df1 keeps the same columns as df0 (needed for later unionByName
# calls and to match the printed result).
df1 = (
    df0.withColumn("sort_id", f.lit(1))
    .unionByName(totals.withColumn("sort_id", f.lit(2)))
    .sort(["week_num", "parent", "brand", "sort_id"])
    .drop("sort_id")
)

df1.show()

result:

+--------+------+-----+-------+-----+
|week_num|parent|brand|channel|usage|
+--------+------+-----+-------+-----+
|       2|     A|   A2|  A2web| 2500|
|       2|     A|   A2|   A2TV| 3500|
|       2|     A|   A2|  Total| 6000|
|       4|     A|   A1|  A2app| 5500|
|       4|     A|   A1|  Total| 5500|
|       4|     A|   AD|  ADapp| 2000|
|       4|     A|   AD|  Total| 2000|
|       4|     B|  B25| B25app| 7600|
|       4|     B|  B25|  Total| 7600|
|       4|     B|  B26| B26app| 5600|
|       4|     B|  B26|  Total| 5600|
|       5|     C|  c25| c25app| 2658|
|       5|     C|  c25|  Total| 2658|
|       5|     C|  c27| c27app| 1100|
|       5|     C|  c27|  Total| 1100|
|       5|     C|  c28| c26app| 1200|
|       5|     C|  c28|  Total| 1200|
+--------+------+-----+-------+-----+

That is OK for the channel column. In order to get something like the output below, I simply repeat the first process (groupBy + sum) and then union the result back:

+--------+------+-----+-------+-----+ 
|week_num|parent|brand|channel|usage|
+--------+------+-----+-------+-----+
|       2|     A|   A2|  A2web| 2500|
|       2|     A|   A2|   A2TV| 3500|
|       2|     A|   A2|  Total| 6000|
|       2|     A|Total|       | 6000|
|       2| Total|     |       | 6000|

Here it is in two steps:

# Sort key shared by both unions below (plain alphabetical ordering).
sort_cols = ["week_num", "parent", "brand", "channel"]

# Brand-level total rows: usage summed per (week_num, parent); the brand
# column carries the "Total" label and channel is left blank.
brand_totals = (
    df0.groupBy("week_num", "parent")
    .agg(f.sum("usage").alias("usage"))
    .withColumn("brand", f.lit("Total"))
    .withColumn("channel", f.lit(""))
)
df2 = df1.unionByName(brand_totals).sort(sort_cols)

# Week-level total rows: usage summed per week_num; the parent column
# carries the "Total" label, brand and channel are left blank.
week_totals = (
    df0.groupBy("week_num")
    .agg(f.sum("usage").alias("usage"))
    .withColumn("parent", f.lit("Total"))
    .withColumn("brand", f.lit(""))
    .withColumn("channel", f.lit(""))
)
df3 = df2.unionByName(week_totals).sort(sort_cols)

result:

+--------+------+-----+-------+-----+
|week_num|parent|brand|channel|usage|
+--------+------+-----+-------+-----+
|       2|     A|   A2|   A2TV| 3500|
|       2|     A|   A2|  A2web| 2500|
|       2|     A|   A2|  Total| 6000|
|       2|     A|Total|       | 6000|
|       2| Total|     |       | 6000|
|       4|     A|   A1|  A2app| 5500|
|       4|     A|   A1|  Total| 5500|
|       4|     A|   AD|  ADapp| 2000|
|       4|     A|   AD|  Total| 2000|
|       4|     A|Total|       | 7500|
|       4|     B|  B25| B25app| 7600|
|       4|     B|  B25|  Total| 7600|
|       4|     B|  B26| B26app| 5600|
|       4|     B|  B26|  Total| 5600|
|       4|     B|Total|       |13200|
|       4| Total|     |       |20700|
|       5|     C|Total|       | 4958|
|       5|     C|  c25|  Total| 2658|
|       5|     C|  c25| c25app| 2658|
|       5|     C|  c27|  Total| 1100|
+--------+------+-----+-------+-----+

First question: is there an alternative or more efficient approach without the repetition? Second: what if I want the total to always be shown at the top of each group, regardless of the alphabetical order of the parent/brand/channel names? How can I sort for that? Something like this (this is dummy data, but I hope it is clear enough):

+--------+------+-----+-------+-----+
|week_num|parent|brand|channel|usage|
+--------+------+-----+-------+-----+
|       2| Total|     |       | 6000|
|       2|     A|Total|       | 6000|
|       2|     A|   A2|  Total| 6000|
|       2|     A|   A2|   A2TV| 3500|
|       2|     A|   A2|  A2web| 2500|
|       4| Total|     |       |20700|
|       4|     A|Total|       | 7500|
|       4|     B|Total|       |13200|
|       4|     A|   A1|  Total| 5500| 
|       4|     A|   A1|  A2app| 5500|
|       4|     A|   AD|  Total| 2000|
|       4|     A|   AD|  ADapp| 2000|
|       4|     B|  B25|  Total| 7600|
|       4|     B|  B25| B25app| 7600|
|       4|     B|  B26|  Total| 5600|
|       4|     B|  B26| B26app| 5600|

Upvotes: 0

Views: 2583

Answers (1)

Steven
Steven

Reputation: 15258

I think you just need the rollup method.

# Hierarchy columns, from coarsest to finest; used both for the rollup and
# for the ordering of the result.
agg_cols = ["week_num", "parent", "brand", "channel"]

# rollup() yields the detail rows plus one subtotal row for every prefix of
# agg_cols (and a grand total); the columns that were aggregated away are null.
# grouping_id() is kept as `lvl` so each row's aggregation level can be told
# apart afterwards.
agg_df = (
    df0.rollup(agg_cols)
    .agg(F.sum("usage").alias("usage"), F.grouping_id().alias("lvl"))
    .orderBy(agg_cols)
)

agg_df.show()
+--------+------+-----+-------+-----+---+
|week_num|parent|brand|channel|usage|lvl|
+--------+------+-----+-------+-----+---+
|    null|  null| null|   null|31658| 15|
|       2|  null| null|   null| 6000|  7|
|       2|     A| null|   null| 6000|  3|
|       2|     A|   A2|   null| 6000|  1|
|       2|     A|   A2|   A2TV| 3500|  0|
|       2|     A|   A2|  A2web| 2500|  0|
|       4|  null| null|   null|20700|  7|
|       4|     A| null|   null| 7500|  3|
|       4|     A|   A1|   null| 5500|  1|
|       4|     A|   A1|  A2app| 5500|  0|
|       4|     A|   AD|   null| 2000|  1|
|       4|     A|   AD|  ADapp| 2000|  0|
|       4|     B| null|   null|13200|  3|
|       4|     B|  B25|   null| 7600|  1|
|       4|     B|  B25| B25app| 7600|  0|
|       4|     B|  B26|   null| 5600|  1|
|       4|     B|  B26| B26app| 5600|  0|
|       5|  null| null|   null| 4958|  7|
|       5|     C| null|   null| 4958|  3|
|       5|     C|  c25|   null| 2658|  1|
+--------+------+-----+-------+-----+---+
only showing top 20 rows

The rest is purely cosmetic. It is probably not a good idea to do that with Spark; better to do it in the reporting tool you will use afterwards.

# `lvl` holds the rollup's grouping_id, which already identifies the
# aggregation level of each row (see the table printed above):
#   0 = detail row, 1 = total per brand, 3 = total per parent,
#   7 = total per week, 15 = grand total.
# These values are used as-is: renumbering them with dense_rank over a window
# that has no partitionBy would pull every row into a single partition.
TOTAL = "Total"
BRAND_TOTAL_LVL = 1
PARENT_TOTAL_LVL = 3
WEEK_TOTAL_LVL = 7
GRAND_TOTAL_LVL = 15

# Put the "Total" label in the first column that was aggregated away and
# blank out the remaining null columns.
agg_df = (
    agg_df.withColumn(
        "parent",
        F.when(F.col("lvl") == WEEK_TOTAL_LVL, TOTAL).otherwise(F.col("parent")),
    )
    .withColumn(
        "brand",
        F.when(F.col("lvl") == PARENT_TOTAL_LVL, TOTAL).otherwise(
            F.coalesce(F.col("brand"), F.lit(""))
        ),
    )
    .withColumn(
        "channel",
        F.when(F.col("lvl") == BRAND_TOTAL_LVL, TOTAL).otherwise(
            F.coalesce(F.col("channel"), F.lit(""))
        ),
    )
)

# Leave out the grand total, then list each week's rows from the coarsest
# total down to the detail rows (a higher lvl means a coarser aggregate).
agg_df.where(F.col("lvl") != GRAND_TOTAL_LVL).orderBy(
    "week_num", F.col("lvl").desc(), "parent", "brand", "channel"
).drop("lvl").show(500)

+--------+------+-----+-------+-----+
|week_num|parent|brand|channel|usage|
+--------+------+-----+-------+-----+
|       2| Total|     |       | 6000|
|       2|     A|Total|       | 6000|
|       2|     A|   A2|  Total| 6000|
|       2|     A|   A2|   A2TV| 3500|
|       2|     A|   A2|  A2web| 2500|
|       4| Total|     |       |20700|
|       4|     A|Total|       | 7500|
|       4|     B|Total|       |13200|
|       4|     A|   A1|  Total| 5500|
|       4|     A|   AD|  Total| 2000|
|       4|     B|  B25|  Total| 7600|
|       4|     B|  B26|  Total| 5600|
|       4|     A|   A1|  A2app| 5500|
|       4|     A|   AD|  ADapp| 2000|
|       4|     B|  B25| B25app| 7600|
|       4|     B|  B26| B26app| 5600|
|       5| Total|     |       | 4958|
|       5|     C|Total|       | 4958|
|       5|     C|  c25|  Total| 2658|
|       5|     C|  c27|  Total| 1100|
|       5|     C|  c28|  Total| 1200|
|       5|     C|  c25| c25app| 2658|
|       5|     C|  c27| c27app| 1100|
|       5|     C|  c28| c26app| 1200|
+--------+------+-----+-------+-----+

Upvotes: 1

Related Questions