tomruarol

Reputation: 69

Pyspark window function with conditions to round number of travelers

I am using Pyspark and I would like to create a function which performs the following operation:

Given data describing the transactions of train users:

+----+----------+--------+-----+
|date|total_trav|num_trav|order|
+----+----------+--------+-----+
|   1|         9|     2.7|    1|
|   1|         9|     1.3|    2|
|   1|         9|     1.3|    3|
|   1|         9|     1.3|    4|
|   1|         9|     1.2|    5|
|   1|         9|     1.1|    6|
|   2|         9|     2.7|    1|
|   2|         9|     1.3|    2|
|   2|         9|     1.3|    3|
|   2|         9|     1.3|    4|
|   2|         9|     1.2|    5|
|   2|         9|     1.1|    6|
+----+----------+--------+-----+

I would like to round the values of the num_trav column, following the order given in the order column and grouping by date, to obtain the trav_res column.

For example, consider this result dataframe and how the trav_res column is formed:

+----+----------+--------+-----+--------+
|date|total_trav|num_trav|order|trav_res|
+----+----------+--------+-----+--------+
|   1|         9|     2.7|    1|       3|
|   1|         9|     1.3|    2|       2|
|   1|         9|     1.3|    3|       2|
|   1|         9|     1.3|    4|       2|
|   1|         9|     1.2|    5|       0|
|   1|         9|     1.1|    6|       0|
|   2|         9|     2.7|    1|       3|
|   2|         9|     1.3|    2|       2|
|   2|         9|     1.3|    3|       2|
|   2|         9|     1.3|    4|       2|
|   2|         9|     1.2|    5|       0|
|   2|         9|     1.1|    6|       0|
+----+----------+--------+-----+--------+

In the example above, when you group by date you get 2 groups, each with a maximum of 9 travelers (total_trav column). Taking group 1, you start by rounding num_trav=2.7 up to 3 (trav_res column), then num_trav=1.3 to 2, then the next num_trav=1.3 to 2, then the next num_trav=1.3 to 2, always following the order column. At that point there are no travelers left (3+2+2+2=9), so the remaining rows get trav_res=0 regardless of their num_trav values.
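
To make the intent concrete, the per-group logic in plain Python would look roughly like this (just an illustration of the desired behaviour, with a made-up allocate function, not the PySpark solution I am looking for):

import math

def allocate(total_trav, num_trav_ordered):
    # num_trav_ordered holds the num_trav values already sorted by the order column
    remaining = total_trav
    result = []
    for value in num_trav_ordered:
        taken = min(math.ceil(value), remaining)  # never hand out more travelers than are left
        result.append(taken)
        remaining -= taken
    return result

allocate(9, [2.7, 1.3, 1.3, 1.3, 1.2, 1.1])  # -> [3, 2, 2, 2, 0, 0]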

I have tried some UDFs, but they don't seem to do the job.

Upvotes: 1

Views: 533

Answers (2)

tomruarol

Reputation: 69

The solution is based on @Anna K.'s answer, with a small addition so that exactly the total number of travelers (total_trav) is used, no more and no less.

import pyspark.sql.functions as F
from pyspark.sql import Window

# create ceiling column (df_j_test_res is the input dataframe)
df = df_j_test_res.withColumn("num_trav_ceil", F.ceil("num_trav"))

# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))

# keep the ceiling while the cumulative sum fits in total_trav,
# give the single remaining traveler when the sum overshoots total_trav by at most 1,
# and impose 0 otherwise
df = (df
      .withColumn("trav_res",
                  F.when(F.col("num_trav_ceil_cumsum") <= F.col("total_trav"),
                         F.col("num_trav_ceil"))
                   .when((F.col("num_trav_ceil_cumsum") - F.col("total_trav") > 0)
                         & (F.col("num_trav_ceil_cumsum") - F.col("total_trav") <= 1),
                         1)
                   .otherwise(0))
      .select("date", "total_trav", "num_trav", "order", "trav_res"))

Upvotes: 0

Anna K.

Reputation: 1530

You can first apply F.ceil to all rows in num_trav, then create a cumulative sum column based on the ceiling values, and then set the ceiling value to zero when the cumulative sum exceeds total_trav, as in the code below.

# create dataframe
import pyspark.sql.functions as F
from pyspark.sql import Window

data = [(1, 9, 2.7, 1),
        (1, 9, 1.3, 2),
        (1, 9, 1.3, 3),
        (1, 9, 1.3, 4),
        (1, 9, 1.2, 5),
        (1, 9, 1.1, 6),
        (2, 9, 2.7, 1),
        (2, 9, 1.3, 2),
        (2, 9, 1.3, 3),
        (2, 9, 1.3, 4),
        (2, 9, 1.2, 5),
        (2, 9, 1.1, 6)]

df = spark.createDataFrame(data, schema=["date", "total_trav", "num_trav", "order"])

# create ceiling column
df = df.withColumn("num_trav_ceil", F.ceil("num_trav"))

# create cumulative sum column
w = Window.partitionBy("date").orderBy("order")
df = df.withColumn("num_trav_ceil_cumsum", F.sum("num_trav_ceil").over(w))

# impose 0 in trav_res when cumsum exceeds total_trav
df = (df
      .withColumn("trav_res",
                  F.when(F.col("num_trav_ceil_cumsum") <= F.col("total_trav"),
                         F.col("num_trav_ceil"))
                   .otherwise(0))
      .select("date", "total_trav", "num_trav", "order", "trav_res"))

Upvotes: 2
