Abhi

Reputation: 123

Recursive calculation and transposing results for date range in PySpark

I have 10 years of data (kf below) in the format shown, and I am trying to sum sale_amt per month for each department. A snapshot for April 2023 should give me sale_1 as the sum of sale_amt in April 2023, sale_2 as the sum of sale_amt one month prior (March 2023), and so on; months with no values are populated with null. For example, for dept_ID 1 in the data below, the April 2023 snapshot would have sale_1 = 10 + 60 = 70 and sale_2 = 30 + 15 = 45. To create a master table from this 10-year data I run the PySpark code below 10 times, once per year, but I am confused about how to transpose or vertically stack the calculations so that anyone querying my result can get snapshots according to the month and year they filter on (expected result below). Any suggestions on doing this more efficiently in PySpark? TIA!

from pyspark.sql import Window
from pyspark.sql import functions as F

# kf = main file with 10 years of data
mf = kf.dropDuplicates(['dept_ID'])   # one row per department to join the sums onto
gf = kf
month_list = ['1','2','3','4','5','6','7','8','9','10','11','12']
window = Window.partitionBy("dept_ID")
for i in month_list:
    # sum sale_amt for month i within each department, then reduce to one row per department
    df = gf.filter(gf.sale_month == i).withColumn(
        "sale_" + i, F.sum(F.coalesce(F.col('sale_amt'), F.lit(0))).over(window))
    df = df.dropDuplicates(['dept_ID'])
    mf = mf.join(df, mf.dept_ID == df.dept_ID, 'left').drop(df.dept_ID)

kf:

dept_ID sale_amt    sale_date   sale_month  sale_year
1   10  4/1/2023    4   2023
1   60  4/1/2023    4   2023
1   30  3/1/2023    3   2023
1   15  3/1/2023    3   2023
1   12  2/1/2023    2   2023
1   10  1/1/2023    1   2023
1   90  1/1/2023    1   2023
1   40  12/1/2022   12  2022
1   40  11/1/2022   11  2022
1   75  10/1/2022   10  2022
1   30  9/1/2022    9   2022
1   50  9/1/2022    9   2022
1   25  8/1/2022    8   2022
1   40  8/1/2022    8   2022
1   70  7/1/2022    7   2022
1   80  5/1/2022    5   2022
1   10  5/1/2022    5   2022
1   45  4/1/2022    4   2022
1   15  4/1/2022    4   2022
2   10  4/1/2023    4   2023
2   60  4/1/2023    4   2023
2   30  3/1/2023    3   2023
2   15  3/1/2023    3   2023
2   12  2/1/2023    2   2023
2   10  1/1/2023    1   2023
2   90  1/1/2023    1   2023
2   40  12/1/2022   12  2022
2   40  11/1/2022   11  2022
2   80  10/1/2022   10  2022
2   30  9/1/2022    9   2022
3   50  9/1/2022    9   2022
3   25  8/1/2022    8   2022
3   40  8/1/2022    8   2022
3   70  7/1/2022    7   2022
3   80  5/1/2022    5   2022
3   10  5/1/2022    5   2022


Expected result:

(Expected result was posted as an image: one row per dept_ID per snapshot, with columns sale_1 through sale_12 plus year and month.)

Upvotes: 1

Views: 186

Answers (2)

stack0114106

Reputation: 8781

Here is another spark-sql solution.

Assuming the OP wants the snapshot of the 12 months backwards from a given year-month number supplied as input.

A reference view is created with null amt values for all the year-month combinations from January 2010 through December 2024, i.e. 12 × 15 = 180 months.

// read the sample data and parse sale_date (e.g. 4/1/2023) into a proper date column
val df = spark.read.format("csv").option("header","true").option("inferSchema","true").load("sale.csv")
val df2 = df.withColumn("dt2", to_date(col("sale_date"), "M/d/yyyy"))
df2.createOrReplaceTempView("sale")

// reference view: one row per year-month from January 2010 onward, with a null amount
val ref_df = spark.sql(" select add_months('2010-01-01',id) yyyymm, cast(null as decimal(15,3)) amt from range(12*15) order by 1 ")
ref_df.createOrReplaceTempView("ref")

spark.conf.set("spark.sql.crossJoin.enabled","true") // to allow the cross join below
// Inputs
val year_month=202304
val month=year_month.toString.drop(4)
val year=year_month.toString.take(4)
val yearp=year.toInt-1
val dfs = spark.sql(s""" with t1 as (
        select dept_id, date_format(dt2,'yyyyMM') yyyymm, sum(sale_amt) sale_amt
        from sale where year(dt2) in (${year},${yearp}) group by 1,2
        union all
        -- cross join the reference months with every dept_id so months without sales survive as nulls
        select dept_id, date_format(yyyymm,'yyyyMM') yyyymm, amt
        from ref, (select distinct(dept_id) dept_id from sale)
        where year(yyyymm) in (${year},${yearp})
     )
     select dept_id, yyyymm, sale_amt amt from t1
     where months_between(to_date('${year_month}01','yyyyMMdd'),to_date(yyyymm||'01','yyyyMMdd')) < 12
     and   months_between(to_date('${year_month}01','yyyyMMdd'),to_date(yyyymm||'01','yyyyMMdd')) >= 0
""")

dfs.show(50,false)

dfs.createOrReplaceTempView("sale2")

// comma-separated, quoted list of yyyymm values (descending) for the PIVOT IN clause
val vl_ym = spark.sql("select collect_set(yyyymm) from sale2").as[Seq[String]].first.map( x => "'"+x+"'").sorted.reverse.mkString(",")
val dfs2 = spark.sql(s"""
select * from (select dept_id, yyyymm, sum(amt) amt from sale2 group by 1,2 ) t
 PIVOT ( 
  sum(amt) as amt
  FOR yyyymm IN ( ${vl_ym} )
  ) 
""")

dfs2.withColumn("year",lit(year)).withColumn("month",lit(month)).orderBy("dept_id").show(false)

The final output is:

+-------+------+------+------+-------+------+------+------+------+------+------+------+------+----+-----+
|dept_id|202304|202303|202302|202301 |202212|202211|202210|202209|202208|202207|202206|202205|year|month|
+-------+------+------+------+-------+------+------+------+------+------+------+------+------+----+-----+
|1      |70.000|45.000|12.000|100.000|40.000|40.000|75.000|80.000|65.000|70.000|null  |90.000|2023|04   |
|2      |70.000|45.000|12.000|100.000|40.000|40.000|80.000|30.000|null  |null  |null  |null  |2023|04   |
|3      |null  |null  |null  |null   |null  |null  |null  |50.000|65.000|70.000|null  |90.000|2023|04   |
+-------+------+------+------+-------+------+------+------+------+------+------+------+------+----+-----+

You can rename 202304, 202303, .. to sale_1, sale_2, etc.; a sketch of that renaming follows.
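
A minimal PySpark sketch of that renaming, assuming the pivoted result from above is available as a DataFrame named pivoted (a hypothetical name; dfs2 is its Scala counterpart):

# rename the descending yyyymm columns to sale_1 (most recent) .. sale_n;
# dept_id, year and month are left untouched
ym_cols = sorted([c for c in pivoted.columns if c.isdigit()], reverse=True)
renamed = pivoted
for i, c in enumerate(ym_cols, start=1):
    renamed = renamed.withColumnRenamed(c, 'sale_' + str(i))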

The PySpark version of building the concatenated string:

>>> seqString=spark.sql("select array_sort(collect_set(yyyymm)) from ( select date_format(add_months(date'2023-01-01',id),'yyyyMM') yyyymm from range(12)) ").collect()[0][0]
>>> sorted_seq = sorted(seqString, reverse=True)
>>> addPrefix=",".join(map(lambda x: "'" + x + "'",sorted_seq))
>>> addPrefix
"'202312','202311','202310','202309','202308','202307','202306','202305','202304','202303','202302','202301'"
>>>
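
That string can then be interpolated into the PIVOT's IN clause just as in the Scala version; a minimal sketch, assuming the sale2 view created above exists:

>>> dfs2 = spark.sql(f"""
... select * from (select dept_id, yyyymm, sum(amt) amt from sale2 group by 1,2) t
... PIVOT ( sum(amt) as amt FOR yyyymm IN ( {addPrefix} ) )
... """)
>>> dfs2.show(truncate=False)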

Upvotes: 1

Lamanus

Reputation: 13581

I think this is a simple pivot problem.

from pyspark.sql import functions as f

year = '2023'
month = '04'

# flag the snapshot month, rank all months by recency (the snapshot month itself
# is rank 1), keep the most recent 12, and pivot the ranks into sale_N columns
df.withColumn('is_current', f.expr(f'if(sale_year = {year} and sale_month = {month}, true, false)')) \
  .withColumn('order', f.expr(f'if(sale_year = {year} and sale_month = {month}, 1, dense_rank() over (partition by dept_ID, is_current order by sale_year desc, sale_month desc) + 1)')) \
  .filter('order <= 12') \
  .withColumn('col_name', f.concat(f.lit('sale_'), f.col('order'))) \
  .groupBy('dept_ID') \
  .pivot('col_name') \
  .agg(f.sum('sale_amt')) \
  .withColumn('year', f.lit(year)) \
  .withColumn('month', f.lit(month)) \
  .show()

+-------+------+-------+-------+------+------+------+------+------+------+------+------+----+-----+
|dept_ID|sale_1|sale_10|sale_12|sale_2|sale_3|sale_4|sale_5|sale_6|sale_7|sale_8|sale_9|year|month|
+-------+------+-------+-------+------+------+------+------+------+------+------+------+----+-----+
|      1|    70|     70|     90|    45|    12|   100|    40|    40|    75|    80|    65|2023|   04|
+-------+------+-------+-------+------+------+------+------+------+------+------+------+----+-----+
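
To get the vertically stacked master table the question asks for, one option is to wrap this snapshot logic in a function, run it once per year-month, and union the results. A minimal sketch, not from the answer: the names snapshot and master and the extra date filter are my additions, the pivot gets a fixed column list so the per-month schemas stay union-compatible, and df is assumed to hold the full kf data:

from functools import reduce
from pyspark.sql import functions as f

def snapshot(src, year, month):
    # one snapshot: sale_1 = the snapshot month, sale_2 = one month prior, ...
    return (src
        .filter(f'sale_year * 100 + sale_month <= {year}{month}')  # drop months after the snapshot
        .withColumn('is_current', f.expr(f'sale_year = {year} and sale_month = {month}'))
        .withColumn('order', f.expr(
            f'if(sale_year = {year} and sale_month = {month}, 1, '
            'dense_rank() over (partition by dept_ID, is_current '
            'order by sale_year desc, sale_month desc) + 1)'))
        .filter('order <= 12')
        .withColumn('col_name', f.concat(f.lit('sale_'), f.col('order')))
        .groupBy('dept_ID')
        .pivot('col_name', ['sale_' + str(i) for i in range(1, 13)])  # fixed column set
        .agg(f.sum('sale_amt'))
        .withColumn('year', f.lit(year))
        .withColumn('month', f.lit(month)))

# one snapshot per month; widen the year range to cover all 10 years
snapshots = [snapshot(df, str(y), '%02d' % m)
             for y in range(2022, 2024) for m in range(1, 13)]
master = reduce(lambda a, b: a.unionByName(b), snapshots)

# querying the stacked result for a specific snapshot
master.filter("year = '2023' and month = '04'").show()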

Upvotes: 1
