Secil Sozuer

PySpark - Creating a DataFrame with a user-defined aggregate function and pivoting

I need to write a user-defined aggregate function that computes the number of days between the previous DISCHARGE_DATE and the following ADMIT_DATE for each pair of consecutive visits.

I will also need to pivot on the "PERSON_ID" values.

I have the following input_df:

input_df:

+---------+----------+--------------+
|PERSON_ID|ADMIT_DATE|DISCHARGE_DATE|
+---------+----------+--------------+
|      111|2018-03-15|    2018-03-16|
|      333|2018-06-10|    2018-06-11|
|      111|2018-03-01|    2018-03-02|
|      222|2018-12-01|    2018-12-02|
|      222|2018-12-05|    2018-12-06|
|      111|2018-03-30|    2018-03-31|
|      333|2018-06-01|    2018-06-02|
|      333|2018-06-20|    2018-06-21|
|      111|2018-01-01|    2018-01-02|
+---------+----------+--------------+

First, I need to group the rows by person and sort each person's rows by ADMIT_DATE. That yields input_df2.

input_df2:

+---------+----------+--------------+
|PERSON_ID|ADMIT_DATE|DISCHARGE_DATE|
+---------+----------+--------------+
|      111|2018-01-01|    2018-01-02|
|      111|2018-03-01|    2018-03-02|
|      111|2018-03-15|    2018-03-16|
|      111|2018-03-30|    2018-03-31|
|      222|2018-12-01|    2018-12-02|
|      222|2018-12-05|    2018-12-06|
|      333|2018-06-01|    2018-06-02|
|      333|2018-06-10|    2018-06-11|
|      333|2018-06-20|    2018-06-21|
+---------+----------+--------------+

The desired output_df:

+------------------+-----------------+-----------------+----------------+
|PERSON_ID_DISTINCT| FIRST_DIFFERENCE|SECOND_DIFFERENCE|THIRD_DIFFERENCE|
+------------------+-----------------+-----------------+----------------+
|               111|  1 month 27 days|          13 days|         14 days|
|               222|           3 days|              NAN|             NAN|
|               333|           8 days|           9 days|             NAN|
+------------------+-----------------+-----------------+----------------+
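The expected gaps can be sanity-checked without Spark. The plain-Python sketch below uses only the standard library and is not the Spark solution. It sorts each person's visits by ADMIT_DATE, as in input_df2, and subtracts the previous DISCHARGE_DATE from the next ADMIT_DATE. Every gap comes out as a plain day count, so person 111's first gap appears as 58 days rather than as a month-and-days figure.

```python
import datetime
from collections import defaultdict

# (PERSON_ID, ADMIT_DATE, DISCHARGE_DATE) rows copied from input_df
rows = [
    (111, '2018-03-15', '2018-03-16'),
    (333, '2018-06-10', '2018-06-11'),
    (111, '2018-03-01', '2018-03-02'),
    (222, '2018-12-01', '2018-12-02'),
    (222, '2018-12-05', '2018-12-06'),
    (111, '2018-03-30', '2018-03-31'),
    (333, '2018-06-01', '2018-06-02'),
    (333, '2018-06-20', '2018-06-21'),
    (111, '2018-01-01', '2018-01-02'),
]


def parse_date(s):
    return datetime.datetime.strptime(s, '%Y-%m-%d').date()


# Group each person's (admit, discharge) pairs together.
visits = defaultdict(list)
for person_id, admit, discharge in rows:
    visits[person_id].append((parse_date(admit), parse_date(discharge)))

gaps = {}
for person_id, person_visits in sorted(visits.items()):
    person_visits.sort()  # tuples sort by ADMIT_DATE first, like input_df2
    # Days from the previous DISCHARGE_DATE to the following ADMIT_DATE.
    gaps[person_id] = [
        (admit - prev_discharge).days
        for (_, prev_discharge), (admit, _) in zip(person_visits, person_visits[1:])
    ]

print(gaps)  # {111: [58, 13, 14], 222: [3], 333: [8, 9]}
```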

I know the maximum number of times a person appears in input_df, and therefore how many columns should be created, from:

input_df.groupBy('PERSON_ID').count().sort('count', ascending=False).show(5)

Thanks a lot in advance,


Answers (1)

pault

You can use pyspark.sql.functions.datediff() to compute the difference between two dates in days. In this case, you just need to compute the difference between the current row's ADMIT_DATE and the previous row's DISCHARGE_DATE. You can do this by using pyspark.sql.functions.lag() over a Window.

For example, we can compute the duration between visits in days as a new column DURATION.

import pyspark.sql.functions as f
from pyspark.sql import Window

# One window per person, with that person's visits ordered by admission date
w = Window.partitionBy('PERSON_ID').orderBy('ADMIT_DATE')
input_df.withColumn(
        'DURATION',
        # days from the previous visit's discharge to this visit's admission
        f.datediff(f.col('ADMIT_DATE'), f.lag('DISCHARGE_DATE').over(w))
    )\
    .withColumn('INDEX', f.row_number().over(w) - 1)\
    .sort('PERSON_ID', 'INDEX')\
    .show()
#+---------+----------+--------------+--------+-----+
#|PERSON_ID|ADMIT_DATE|DISCHARGE_DATE|DURATION|INDEX|
#+---------+----------+--------------+--------+-----+
#|      111|2018-01-01|    2018-01-02|    null|    0|
#|      111|2018-03-01|    2018-03-02|      58|    1|
#|      111|2018-03-15|    2018-03-16|      13|    2|
#|      111|2018-03-30|    2018-03-31|      14|    3|
#|      222|2018-12-01|    2018-12-02|    null|    0|
#|      222|2018-12-05|    2018-12-06|       3|    1|
#|      333|2018-06-01|    2018-06-02|    null|    0|
#|      333|2018-06-10|    2018-06-11|       8|    1|
#|      333|2018-06-20|    2018-06-21|       9|    2|
#+---------+----------+--------------+--------+-----+

Notice that I also added an INDEX column using pyspark.sql.functions.row_number() minus 1, so each person's first visit has INDEX 0. Because the DURATION of that first visit is always null (there is no previous discharge), we can filter for INDEX > 0 and then pivot the DataFrame:

input_df.withColumn(
        'DURATION',
        f.datediff(f.col('ADMIT_DATE'), f.lag('DISCHARGE_DATE').over(w))
    )\
    .withColumn('INDEX', f.row_number().over(w) - 1)\
    .where('INDEX > 0')\
    .groupBy('PERSON_ID').pivot('INDEX').agg(f.first('DURATION'))\
    .sort('PERSON_ID')\
    .show()
#+---------+---+----+----+
#|PERSON_ID|  1|   2|   3|
#+---------+---+----+----+
#|      111| 58|  13|  14|
#|      222|  3|null|null|
#|      333|  8|   9|null|
#+---------+---+----+----+

The pivoted columns are named after the INDEX values (1, 2 and 3), so you can now rename them to whatever you want, for example with withColumnRenamed().

Note: This assumes that ADMIT_DATE and DISCHARGE_DATE are of type date.

input_df.printSchema()
#root
# |-- PERSON_ID: long (nullable = true)
# |-- ADMIT_DATE: date (nullable = true)
# |-- DISCHARGE_DATE: date (nullable = true)
