Anji

Reputation: 305

Sum the values of a column using PySpark

I have a scenario with two tables: one holds a number of days for each x, and the other holds values. For each row of the second table, I need to sum the values over the number of rows given by days in the first table. The DataFrames:

dataframe1
df1 = spark.createDataFrame(
    [
        ('ll', 5),
        ('yy', 6),
    ],
    ('x', 'days')
)
dataframe2
df = spark.createDataFrame(
    [
        ('ll','2020-01-05','1','10'),
        ('ll','2020-01-06','1','10'),
        ('ll','2020-01-07','1','10'),
        ('ll','2020-01-08','1','10'),
        ('ll','2020-01-09','1','10'),
        ('ll','2020-01-10','1','10'),
        ('ll','2020-01-11','1','20'),
        ('ll','2020-01-12','1','10'),
        ('ll','2020-01-05','2','30'),
        ('ll','2020-01-06','2','30'),
        ('ll','2020-01-07','2','30'),
        ('ll','2020-01-08','2','40'),
        ('ll','2020-01-09','2','30'),
        ('ll','2020-01-10','2','10'),
        ('ll','2020-01-11','2','10'),
        ('ll','2020-01-12','2','10'),
        ('yy','2020-01-05','1','20'),
        ('yy','2020-01-06','1','20'),
        ('yy','2020-01-07','1','20'),
        ('yy','2020-01-08','1','20'),
        ('yy','2020-01-09','1','20'),
        ('yy','2020-01-10','1','40'),
        ('yy','2020-01-11','1','20'),
        ('yy','2020-01-12','1','20'),
        ('yy','2020-01-05','2','40'),
        ('yy','2020-01-06','2','40'),
        ('yy','2020-01-07','2','40'),
        ('yy','2020-01-08','2','40'),
        ('yy','2020-01-09','2','40'),
        ('yy','2020-01-10','2','40'),
        ('yy','2020-01-11','2','60'),
        ('yy','2020-01-12','2','40'),
    ],
    ('x','date','flag','value')
)

expected_dataframe = spark.createDataFrame(
    [
        ('ll','2020-01-05','1','10','50'),
        ('ll','2020-01-06','1','10','50'),
        ('ll','2020-01-07','1','10','60'),
        ('ll','2020-01-08','1','10','60'),
        ('ll','2020-01-09','1','10','50'),
        ('ll','2020-01-10','1','10','40'),
        ('ll','2020-01-11','1','20','30'),
        ('ll','2020-01-12','1','10','10'),
        ('ll','2020-01-05','2','30','170'),
        ('ll','2020-01-06','2','30','140'),
        ('ll','2020-01-07','2','30','120'),
        ('ll','2020-01-08','2','40','100'),
        ('ll','2020-01-09','2','30','60'),
        ('ll','2020-01-10','2','10','30'),
        ('ll','2020-01-11','2','10','20'),
        ('ll','2020-01-12','2','10','10'),
        ('yy','2020-01-05','1','20','140'),
        ('yy','2020-01-06','1','20','140'),
        ('yy','2020-01-07','1','20','140'),
        ('yy','2020-01-08','1','20','120'),
        ('yy','2020-01-09','1','20','100'),
        ('yy','2020-01-10','1','40','80'),
        ('yy','2020-01-11','1','20','40'),
        ('yy','2020-01-12','1','20','20'),
        ('yy','2020-01-05','2','40','240'),
        ('yy','2020-01-06','2','40','260'),
        ('yy','2020-01-07','2','40','260'),
        ('yy','2020-01-08','2','40','220'),
        ('yy','2020-01-09','2','40','180'),
        ('yy','2020-01-10','2','40','140'),
        ('yy','2020-01-11','2','60','100'),
        ('yy','2020-01-12','2','40','40'),
    ],
    ('x','date','flag','value','result')
)

expected_results

    +---+----------+----+-----+------+
    |  x|      date|flag|value|result|
    +---+----------+----+-----+------+
    | ll|2020-01-05|   1|   10|    50|
    | ll|2020-01-06|   1|   10|    50|
    | ll|2020-01-07|   1|   10|    60|
    | ll|2020-01-08|   1|   10|    60|
    | ll|2020-01-09|   1|   10|    50|
    | ll|2020-01-10|   1|   10|    40|
    | ll|2020-01-11|   1|   20|    30|
    | ll|2020-01-12|   1|   10|    10|
    | ll|2020-01-05|   2|   30|   170|
    | ll|2020-01-06|   2|   30|   140|
    | ll|2020-01-07|   2|   30|   120|
    | ll|2020-01-08|   2|   40|   100|
    | ll|2020-01-09|   2|   30|    60|
    | ll|2020-01-10|   2|   10|    30|
    | ll|2020-01-11|   2|   10|    20|
    | ll|2020-01-12|   2|   10|    10|
    | yy|2020-01-05|   1|   20|   140|
    | yy|2020-01-06|   1|   20|   140|
    | yy|2020-01-07|   1|   20|   140|
    | yy|2020-01-08|   1|   20|   120|
    | yy|2020-01-09|   1|   20|   100|
    | yy|2020-01-10|   1|   40|    80|
    | yy|2020-01-11|   1|   20|    40|
    | yy|2020-01-12|   1|   20|    20|
    | yy|2020-01-05|   2|   40|   240|
    | yy|2020-01-06|   2|   40|   260|
    | yy|2020-01-07|   2|   40|   260|
    | yy|2020-01-08|   2|   40|   220|
    | yy|2020-01-09|   2|   40|   180|
    | yy|2020-01-10|   2|   40|   140|
    | yy|2020-01-11|   2|   60|   100|
    | yy|2020-01-12|   2|   40|    40|
    +---+----------+----+-----+------+

code

from pyspark.sql.window import Window
from pyspark.sql.functions import col, to_date

# join the days onto every value row and parse the date strings
df_join = df.join(df1, ['x'], 'inner').withColumn('date', to_date(col('date'), 'yyyy-MM-dd'))
w1 = Window.partitionBy('x', 'flag').orderBy(col('date').desc())

So I need to sum the value column based on the days column, i.e. if days is 5, I need to sum 5 rows of values: the current row and the next 4 rows, ordered by date within each x and flag.
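To pin down the rule, here is a plain-Python sketch (not Spark) of the computation the expected output implies, assuming each (x, flag) group is sorted by date and the sum is simply cut short at the end of the group. The function name `forward_sums` and the sample rows (the `ll` part of dataframe2) are illustrative only.

```python
from collections import defaultdict

def forward_sums(days_by_x, rows):
    """Sum `value` over the current row and the next (days - 1) rows of the
    same (x, flag) group, ordered by date; windows are truncated at the end
    of the group. Hypothetical helper, for illustration only."""
    groups = defaultdict(list)
    for x, date, flag, value in rows:
        groups[(x, flag)].append((date, int(value)))
    result = {}
    for (x, flag), items in groups.items():
        items.sort()  # 'yyyy-MM-dd' strings sort chronologically
        values = [v for _, v in items]
        n = days_by_x[x]
        for i, (date, _) in enumerate(items):
            result[(x, date, flag)] = sum(values[i:i + n])
    return result

dates = [f"2020-01-{d:02d}" for d in range(5, 13)]
rows = (
    [("ll", d, "1", v) for d, v in zip(dates, [10, 10, 10, 10, 10, 10, 20, 10])]
    + [("ll", d, "2", v) for d, v in zip(dates, [30, 30, 30, 40, 30, 10, 10, 10])]
)
sums = forward_sums({"ll": 5, "yy": 6}, rows)
print([sums[("ll", d, "2")] for d in dates])
# [160, 140, 120, 100, 60, 30, 20, 10]
```

For flag 1 this reproduces the expected column exactly; for flag 2 it gives 160 on 2020-01-05 rather than the 170 listed in the expected table.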

I joined the two tables and tried to solve this with a window function, but it didn't work and I can't figure out how to solve it. Can anyone show me how?

Upvotes: 3

Views: 1013

Answers (1)

murtihash

Reputation: 8410

First, you could join on x, then compute a row_number() over the rows of each (x, flag) group ordered by date. Use it to single out the rows whose row number is greater than days (turn their values into nulls), then sum over a window that is only partitioned (no ordering), which broadcasts the sum across all rows of the group.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# rows ordered by date within each (x, flag) group
w = Window.partitionBy("x", "flag").orderBy(F.to_date("date", "yyyy-MM-dd"))
# partition-only window: the frame is the whole group
w1 = Window.partitionBy("x", "flag")

(df.join(df1, ["x"])
   .withColumn("rowNum", F.row_number().over(w))
   # null out values past the first `days` rows, then sum over the whole group
   .withColumn("expected_result",
               F.sum(F.when(F.col("rowNum") > F.col("days"), F.lit(None))
                      .otherwise(F.col("value")))
                .over(w1))
   .drop("days", "rowNum")
   .show())

#+---+----------+----+-----+---------------+
#|  x|      date|flag|value|expected_result|
#+---+----------+----+-----+---------------+
#| ll|2020-01-05|   1|   10|           50.0|
#| ll|2020-01-06|   1|   10|           50.0|
#| ll|2020-01-07|   1|   10|           50.0|
#| ll|2020-01-08|   1|   10|           50.0|
#| ll|2020-01-09|   1|   10|           50.0|
#| ll|2020-01-10|   1|   10|           50.0|
#| ll|2020-01-11|   1|   10|           50.0|
#| ll|2020-01-12|   1|   10|           50.0|
#| ll|2020-01-05|   2|   30|          150.0|
#| ll|2020-01-06|   2|   30|          150.0|
#| ll|2020-01-07|   2|   30|          150.0|
#| ll|2020-01-08|   2|   30|          150.0|
#| ll|2020-01-09|   2|   30|          150.0|
#| ll|2020-01-10|   2|   10|          150.0|
#| ll|2020-01-11|   2|   10|          150.0|
#| ll|2020-01-12|   2|   10|          150.0|
#| yy|2020-01-05|   1|   20|          120.0|
#| yy|2020-01-06|   1|   20|          120.0|
#| yy|2020-01-07|   1|   20|          120.0|
#| yy|2020-01-08|   1|   20|          120.0|
#+---+----------+----+-----+---------------+
#only showing top 20 rows
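To see what this first version computes, here is a rough plain-Python model of a single (x, flag) group; the helper name `first_approach` is hypothetical. row_number() masks everything after the first `days` rows, and the partition-only window repeats one total on every row, which is why each group above shows a single number instead of the per-row sums the question asks for.

```python
# Plain-Python model of the first query for a single (x, flag) group.
# `values` must already be in date order (what the ordered window `w` provides).
def first_approach(values, days):
    # row_number() is 1-based; values past the first `days` rows become None (null)
    masked = [v if row_num <= days else None
              for row_num, v in enumerate(values, start=1)]
    # sum() over the partition-only window `w1` skips nulls and
    # returns the same group total on every row
    total = sum(v for v in masked if v is not None)
    return [total] * len(values)

print(first_approach([10, 10, 10, 10, 10, 10, 20, 10], days=5))
# [50, 50, 50, 50, 50, 50, 50, 50]
```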

UPDATE:

For Spark 2.4+, you could use the higher-order functions transform and aggregate after collect_list. I assumed the data is ordered as in the example provided; if that's not the case, an extra step is needed to ensure that ordering.

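from pyspark.sql import functions as F
from pyspark.sql.window import Window

# partition-only window: collect_list sees every value of the (x, flag) group
w = Window.partitionBy("x", "flag")
# date-ordered window, used for each row's position in its group
w1 = Window.partitionBy("x", "flag").orderBy(F.to_date("date", "yyyy-MM-dd"))

(df.join(df1, ["x"])
   .withColumn("result", F.collect_list("value").over(w))
   .withColumn("rowNum", F.row_number().over(w1) - 1)  # 0-based position
   # pair each value with its index, then add up the values whose index
   # falls in [rowNum, rowNum + days)
   .withColumn("result", F.expr("""
       aggregate(transform(result, (x, i) -> array(x, i)), 0,
                 (acc, x) -> IF(int(x[1]) >= rowNum AND int(x[1]) < days + rowNum,
                                int(x[0]) + acc, acc))"""))
   .drop("flag", "rowNum", "days")
   .show())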


#+---+----------+-----+------+
#|  x|      date|value|result|
#+---+----------+-----+------+
#| ll|2020-01-05|   10|    50|
#| ll|2020-01-06|   10|    50|
#| ll|2020-01-07|   10|    60|
#| ll|2020-01-08|   10|    60|
#| ll|2020-01-09|   10|    50|
#| ll|2020-01-10|   10|    40|
#| ll|2020-01-11|   20|    30|
#| ll|2020-01-12|   10|    10|
#| ll|2020-01-05|   30|   160|
#| ll|2020-01-06|   30|   140|
#| ll|2020-01-07|   30|   120|
#| ll|2020-01-08|   40|   100|
#| ll|2020-01-09|   30|    60|
#| ll|2020-01-10|   10|    30|
#| ll|2020-01-11|   10|    20|
#| ll|2020-01-12|   10|    10|
#| yy|2020-01-05|   20|   140|
#| yy|2020-01-06|   20|   140|
#| yy|2020-01-07|   20|   140|
#| yy|2020-01-08|   20|   120|
#| yy|2020-01-09|   20|   100|
#| yy|2020-01-10|   40|    80|
#| yy|2020-01-11|   20|    40|
#| yy|2020-01-12|   20|    20|
#| yy|2020-01-05|   40|   240|
#| yy|2020-01-06|   40|   260|
#| yy|2020-01-07|   40|   260|
#| yy|2020-01-08|   40|   220|
#| yy|2020-01-09|   40|   180|
#| yy|2020-01-10|   40|   140|
#| yy|2020-01-11|   60|   100|
#| yy|2020-01-12|   40|    40|
#+---+----------+-----+------+
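The fold inside `F.expr` can be hard to read, so here is a rough plain-Python model of it for a single (x, flag) group, assuming (as the answer does) that the collected list is already in date order. `update_approach` is a hypothetical name; `functools.reduce` plays the role of `aggregate`, and the `(i, v)` pairs stand in for `transform(result, (x, i) -> array(x, i))`.

```python
from functools import reduce

# Plain-Python model of the UPDATE's aggregate/transform expression for a
# single (x, flag) group; assumes `values` is already in date order.
def update_approach(values, days):
    # collect_list over the partition-only window: every row sees the whole list
    collected = list(values)
    # transform(result, (x, i) -> array(x, i)): pair each value with its index
    pairs = [(i, v) for i, v in enumerate(collected)]
    out = []
    for row_num in range(len(collected)):  # row_number() - 1
        # aggregate(pairs, 0, ...): add v only when row_num <= i < row_num + days
        total = reduce(
            lambda acc, p: acc + p[1] if row_num <= p[0] < row_num + days else acc,
            pairs,
            0,
        )
        out.append(total)
    return out

print(update_approach([30, 30, 30, 40, 30, 10, 10, 10], days=5))
# [160, 140, 120, 100, 60, 30, 20, 10]
```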

Also, in your example, row 9 (ll, 2020-01-05, flag 2) should be 160 instead of 170.
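The corrected value is the sum of the first five flag 2 values for ll (days = 5 for ll), 2020-01-05 through 2020-01-09:

```latex
30 + 30 + 30 + 40 + 30 = 160
```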

Upvotes: 2
