Distribute the value of a Spark DataFrame row proportionally to other rows

When the team column contains the value Common, I have to split that row's price proportionally among the other teams that are part of the same sale (same id_sales).

|id_sales       |team            |price           |
|---------------|----------------|----------------|
|101            |Data Engineering|             200|
|102            |       Front-End|             300|
|103            |  Infrastructure|             100|
|103            |        Software|             200|
|103            |          Common|             800|
|104            |    Data Science|             500|

For example, in the table above, id_sales = 103 has a Common row worth 800, so I have to work out how much of it each team in that sale gets, based on their own prices:

- Infrastructure: 100
- Software: 200

So Infrastructure gets 1/3 * 800 and Software gets 2/3 * 800, each added to its own price. In the end my table should look like this:

|id_sales       |team            |price           |
|---------------|----------------|----------------|
|101            |Data Engineering|             200|
|102            |       Front-End|             300|
|103            |  Infrastructure|          366.67|
|103            |        Software|          733.33|
|104            |    Data Science|             500|
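Stated as a formula (a restatement of the rule above; the symbols are introduced here only for illustration): for a sale $s$, let $T_s$ be its non-Common teams and $p$ the price. Each team keeps its own price and receives a share of the Common price proportional to it:

```latex
\begin{aligned}
p'_t &= p_t + \frac{p_t}{\sum_{u \in T_s} p_u}\, p_{\mathrm{Common},s}, \qquad t \in T_s \\[4pt]
% id_sales = 103: T_s = {Infrastructure, Software}, sum of prices = 300
p'_{\mathrm{Infrastructure}} &= 100 + \tfrac{100}{300} \cdot 800 \approx 366.67 \\
p'_{\mathrm{Software}} &= 200 + \tfrac{200}{300} \cdot 800 \approx 733.33
\end{aligned}
```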

Could someone give me an idea or a hint, please? Python or Scala is fine (Spark 2.4).

The code to create this table in PySpark:


```python
# Assumes an existing SparkSession named `spark` (as in the pyspark shell)
spark_df = spark.createDataFrame(
    [
        ("101", "Data Engineering", "200"),
        ("102", "Front-End", "300"),
        ("103", "Infrastructure", "100"),
        ("103", "Software", "200"),
        ("103", "Common", "800"),
        ("104", "Data Science", "500"),
    ],
    ["id_sales", "team", "price"],
)
```

And in Scala:

```scala
val spark_df = Seq(
  ("101", "Data Engineering", "200"),
  ("102", "Front-End", "300"),
  ("103", "Infrastructure", "100"),
  ("103", "Software", "200"),
  ("103", "Common", "800"),
  ("104", "Data Science", "500")
).toDF("id_sales", "team", "price")
```

Thanks :)

Parvez Patel
Try this:

```scala
scala> val df = Seq(
     |   ("101", "Data Engineering", "200"),
     |   ("102", "Front-End", "300"),
     |   ("103", "Infrastructure", "100"),
     |   ("103", "Software", "200"),
     |   ("103", "Common", "800"),
     |   ("104", "Data Science", "500")
     | ).toDF("id_sales", "team", "price")
df: org.apache.spark.sql.DataFrame = [id_sales: string, team: string ... 1 more field]

scala> df.show
+--------+----------------+-----+
|id_sales|            team|price|
+--------+----------------+-----+
|     101|Data Engineering|  200|
|     102|       Front-End|  300|
|     103|  Infrastructure|  100|
|     103|        Software|  200|
|     103|          Common|  800|
|     104|    Data Science|  500|
+--------+----------------+-----+

scala> val commonDF = df.filter("team='Common'")
commonDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id_sales: string, team: string ... 1 more field]

scala> import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.Window

scala> val ww = Window.partitionBy("id_sales")
ww: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@43324745

scala> val finalDF = df.as("main").
     |   filter("team<>'Common'").
     |   withColumn("weight", col("price") / sum("price").over(ww)).
     |   join(commonDF.as("common"), Seq("id_sales"), "left").
     |   withColumn("updated_price",
     |     when(col("common.price").isNull, df("price")).
     |       otherwise(df("price") + col("weight") * col("common.price"))).
     |   select($"id_sales", $"main.team", $"updated_price".as("price"))
finalDF: org.apache.spark.sql.DataFrame = [id_sales: string, team: string ... 1 more field]

scala> finalDF.show
+--------+----------------+------------------+
|id_sales|            team|             price|
+--------+----------------+------------------+
|     101|Data Engineering|               200|
|     104|    Data Science|               500|
|     102|       Front-End|               300|
|     103|        Software| 733.3333333333333|
|     103|  Infrastructure|366.66666666666663|
+--------+----------------+------------------+
```
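The three steps of this answer (per-sale weight over the non-Common rows, left join of the Common price by id_sales, then adding weight * Common price) can be checked without a cluster. The sketch below is plain Python, not Spark code; `distribute_common` is a hypothetical helper, and it assumes rows are `(id_sales, team, price)` tuples with prices as strings, as in the question.

```python
from collections import defaultdict


def distribute_common(rows, common_team="Common"):
    """Spread each sale's Common price over the other teams of that sale,
    proportionally to their own prices (plain-Python sketch, no Spark)."""
    # Step 1: Common price per id_sales (the commonDF side of the join)
    common = {}
    for id_sales, team, price in rows:
        if team == common_team:
            common[id_sales] = float(price)

    # Step 2: total of the non-Common prices per id_sales, i.e. what
    # sum("price").over(Window.partitionBy("id_sales")) gives after the filter
    totals = defaultdict(float)
    for id_sales, team, price in rows:
        if team != common_team:
            totals[id_sales] += float(price)

    # Step 3: weight = price / total; add weight * Common price only when
    # the sale has a Common row (the left join plus when/otherwise)
    result = []
    for id_sales, team, price in rows:
        if team == common_team:
            continue  # Common rows are dropped from the output
        value = float(price)
        if id_sales in common:
            value += value / totals[id_sales] * common[id_sales]
        result.append((id_sales, team, value))
    return result


rows = [
    ("101", "Data Engineering", "200"),
    ("102", "Front-End", "300"),
    ("103", "Infrastructure", "100"),
    ("103", "Software", "200"),
    ("103", "Common", "800"),
    ("104", "Data Science", "500"),
]

for row in distribute_common(rows):
    print(row)
```

Like the Spark version, this drops the Common rows and leaves sales without a Common row unchanged; a sale that has only a Common row disappears entirely. If a sale's non-Common prices sum to zero, this sketch raises `ZeroDivisionError`, whereas Spark 2.4 division returns null.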

