Gary Chan Chi Hang

Is there a way to partition/group data so that the sum of a column's values in each group stays under a limit?

I want to partition/group the rows so that the total size of each group is <= limit.

For example, if I have:

+--------+----------+
|      id|      size|
+--------+----------+
|       1|         3|
|       2|         6|
|       3|         8|
|       4|         5|
|       5|         7|
|       6|         7|
+--------+----------+

and I want to group consecutive rows (in id order) so that each group's total size is <= 10, the result would be:

+--------+----------+----------+
|      id|      size|     group|
+--------+----------+----------+
|       1|         3|         0|
|       2|         6|         0|
|       3|         8|         1|
|       4|         5|         2|
|       5|         7|         3|
|       6|         7|         4|
+--------+----------+----------+

Another example, with a limit of 13 (total size <= 13 per group):

+--------+----------+----------+
|      id|      size|     group|
+--------+----------+----------+
|       1|         3|         0|
|       2|         6|         0|
|       3|         8|         1|
|       4|         5|         1|
|       5|         7|         2|
|       6|         7|         3|
+--------+----------+----------+
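Both expected outputs above follow a greedy, order-preserving rule (often called next-fit bin packing): walk the rows in id order and start a new group whenever adding the next size would push the running total past the limit. A minimal stdlib-only sketch of that rule, assuming the rows are already sorted by id; the function name `assign_groups` is hypothetical:

```python
def assign_groups(sizes, limit):
    """Return a group number for each size, filling groups in order.

    A new group starts whenever adding the next size would make the
    current group's total exceed `limit`. A single size larger than
    `limit` still gets a group of its own.
    """
    groups = []
    group, total = 0, 0
    for size in sizes:
        if total > 0 and total + size > limit:
            group += 1   # current group is full: open the next one
            total = 0
        total += size
        groups.append(group)
    return groups


sizes = [3, 6, 8, 5, 7, 7]
print(assign_groups(sizes, 10))  # [0, 0, 1, 2, 3, 4]
print(assign_groups(sizes, 13))  # [0, 0, 1, 1, 2, 3]
```

Because each row's group depends on the running total left over by the previous rows, the rule is inherently sequential, which is why it is awkward to express with a plain cumulative-sum window column.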

I'm not quite sure where to start. I have looked into window functions, reduce functions, user-defined aggregate functions, and adding extra columns (e.g. a cumulative sum).

The original task is to group request payloads so that payloads whose combined size is under a limit can be sent in a single request.

Answers (1)

user238607

Since exact correctness is not required for your application, an approximate solution is acceptable.

First, we randomly assign the rows to groups of roughly payload_limit / avg_size rows each. Then, on each of those groups, we use a grouped-map pandas_udf to split the rows into subgroups whose total size does not exceed payload_limit.
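To see what the two stages do without a cluster, here is a stdlib-only simulation. It is a sketch under stated assumptions, not the Spark code: Python's `random.Random` stands in for Spark's `rand()`, each bucket is filled with a greedy order-preserving pass, and the helper names `fill` and `two_stage_groups` are hypothetical.

```python
import math
import random
from collections import defaultdict


def fill(sizes, limit):
    # Greedy, order-preserving fill: start a new sub-group when adding the
    # next size would push the running total past `limit`.
    groups, group, total = [], 0, 0
    for size in sizes:
        if total > 0 and total + size > limit:
            group, total = group + 1, 0
        total += size
        groups.append(group)
    return groups


def two_stage_groups(rows, limit, seed=0):
    """rows: non-empty list of (id, size). Returns {(bucket, sub_group): [ids]}."""
    rng = random.Random(seed)  # stands in for Spark's rand()
    avg_size = sum(size for _, size in rows) / len(rows)

    # Stage 1: about limit / avg_size rows per bucket, bucket chosen at random.
    rows_per_bucket = max(1, math.floor(limit / avg_size))
    num_buckets = math.ceil(len(rows) / rows_per_bucket)
    buckets = defaultdict(list)
    for row in rows:
        buckets[rng.randrange(num_buckets)].append(row)

    # Stage 2: inside each bucket, sort by id and fill greedily.
    result = defaultdict(list)
    for bucket, members in buckets.items():
        members.sort()
        subs = fill([size for _, size in members], limit)
        for (row_id, _), sub in zip(members, subs):
            result[(bucket, sub)].append(row_id)
    return dict(result)


rows = [(1, 3), (2, 6), (3, 8), (4, 5), (5, 7), (6, 7)]
for key, ids in sorted(two_stage_groups(rows, 20).items()):
    print(key, ids)
```

Whatever the random bucketing does, every sub-group stays within the limit as long as no single row is larger than it; the randomness only affects how full the batches end up, which is where the approximation comes in.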

Here's a possible PySpark implementation.

import math
from pyspark.sql.functions import avg, floor, rand, pandas_udf, PandasUDFType
from pyspark.sql.functions import col, count, sum  # note: this `sum` shadows Python's built-in sum
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StructField, StructType

spark = SparkSession.builder \
    .appName("Example") \
    .getOrCreate()

data = [
    (1, 3),
    (2, 6),
    (3, 8),
    (4, 5),
    (5, 7),
    (6, 7),
    (11, 3),
    (12, 6),
    (13, 8),
    (14, 5),
    (15, 7),
    (16, 7),
    (21, 3),
    (22, 6),
    (23, 8),
    (24, 5),
    (25, 7),
    (26, 7)
]
df = spark.createDataFrame(data, ["id", "size"])

# Find the average size
avg_size = df.select(avg("size")).collect()[0][0]

payload_limit = 20
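# Rows per random group that should fit under the limit on average (at least 1,
# which also avoids a division by zero below when avg_size > payload_limit).
rows_per_group = max(1, math.floor(payload_limit / avg_size))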
print(f"{avg_size=}")
print(f"{rows_per_group=}")
print(f"{df.count()=}")
nums_of_group = math.ceil(df.count() / rows_per_group)
print(f"{nums_of_group=}")

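# Randomly assign each row to one of nums_of_group buckets. rand() is unseeded,
# so the assignment (and the output below) differs between runs, and some
# buckets may end up empty (bucket 0 did in the run shown).
df = df.withColumn("random_group_id", floor(rand() * nums_of_group))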
distinct_group_ids = df.select(col("random_group_id")).distinct()
distinct_group_ids.show(n=100, truncate=False)
print(f"{distinct_group_ids.count()}")

grouped_counts = df.groupby(col("random_group_id")).agg(count("*"))
grouped_counts.show(n=100, truncate=False)

df.show(n=100, truncate=False)

result_schema = StructType([
    StructField("id", IntegerType()),
    StructField("size", IntegerType()),
    StructField("random_group_id", IntegerType()),
    StructField("sub_group", IntegerType()),
])


@pandas_udf(result_schema, PandasUDFType.GROUPED_MAP)
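def group_by_limit(pdf):
    # pdf holds all rows of ONE random_group_id. Spark does not guarantee
    # their order, so sort by id to make the packing deterministic.
    pdf = pdf.sort_values("id").reset_index(drop=True)

    # Greedy pass: keep adding rows to the current sub_group while the running
    # total stays <= payload_limit; otherwise start a new sub_group.
    # (Bucketing by cum_size // limit alone could let a sub_group exceed the limit.)
    sub_groups = []
    current_group, running = 0, 0
    for size in pdf["size"]:
        if running > 0 and running + size > payload_limit:
            current_group += 1
            running = 0
        running += size
        sub_groups.append(current_group)

    pdf["sub_group"] = sub_groups
    # Column order must match result_schema.
    return pdf[["id", "size", "random_group_id", "sub_group"]]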


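# Apply the grouped-map pandas UDF to each random group.
# (On Spark 3.x this style is deprecated; the equivalent is to define
# group_by_limit without the decorator and call
# df.groupby("random_group_id").applyInPandas(group_by_limit, schema=result_schema).)
grouped_df = df.groupby("random_group_id").apply(group_by_limit)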


grouped_df.show()

# Verify: each (random_group_id, sub_group) total should not exceed payload_limit.
result_grouped = grouped_df.groupBy("random_group_id", "sub_group").agg(sum("size"))
result_grouped.orderBy("random_group_id", "sub_group").show(n=100, truncate=False)

Output:

avg_size=6.0
rows_per_group=3
df.count()=18
nums_of_group=6
+---------------+
|random_group_id|
+---------------+
|5              |
|3              |
|1              |
|2              |
|4              |
+---------------+

5
+---------------+--------+
|random_group_id|count(1)|
+---------------+--------+
|5              |3       |
|3              |4       |
|1              |5       |
|2              |2       |
|4              |4       |
+---------------+--------+

+---+----+---------------+
|id |size|random_group_id|
+---+----+---------------+
|1  |3   |5              |
|2  |6   |3              |
|3  |8   |1              |
|4  |5   |5              |
|5  |7   |1              |
|6  |7   |1              |
|11 |3   |3              |
|12 |6   |3              |
|13 |8   |5              |
|14 |5   |2              |
|15 |7   |4              |
|16 |7   |1              |
|21 |3   |4              |
|22 |6   |2              |
|23 |8   |3              |
|24 |5   |1              |
|25 |7   |4              |
|26 |7   |4              |
+---+----+---------------+

+---+----+---------------+---------+
| id|size|random_group_id|sub_group|
+---+----+---------------+---------+
|  3|   8|              1|        0|
|  5|   7|              1|        0|
|  6|   7|              1|        1|
| 16|   7|              1|        1|
| 24|   5|              1|        1|
| 14|   5|              2|        0|
| 22|   6|              2|        0|
|  2|   6|              3|        0|
| 11|   3|              3|        0|
| 12|   6|              3|        0|
| 23|   8|              3|        1|
| 15|   7|              4|        0|
| 21|   3|              4|        0|
| 25|   7|              4|        0|
| 26|   7|              4|        1|
|  1|   3|              5|        0|
|  4|   5|              5|        0|
| 13|   8|              5|        0|
+---+----+---------------+---------+

+---------------+---------+---------+
|random_group_id|sub_group|sum(size)|
+---------------+---------+---------+
|1              |0        |15       |
|1              |1        |19       |
|2              |0        |11       |
|3              |0        |15       |
|3              |1        |8        |
|4              |0        |17       |
|4              |1        |7        |
|5              |0        |16       |
+---------------+---------+---------+
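A caveat when bucketing by a cumulative sum (`cum_size // limit`, or a cumulative-sum window column): it constrains where each bucket's running total starts and ends, not the total inside each bucket, so a bucket can reach almost twice the limit. A small stdlib-only check with made-up sizes (not from the question's data); the helper names `cumsum_buckets` and `bucket_totals` are hypothetical:

```python
from itertools import accumulate


def cumsum_buckets(sizes, limit):
    # Bucket index = running total // limit (the cumulative-sum approach).
    return [c // limit for c in accumulate(sizes)]


def bucket_totals(sizes, buckets):
    # Sum the sizes that landed in each bucket.
    totals = {}
    for size, b in zip(sizes, buckets):
        totals[b] = totals.get(b, 0) + size
    return totals


sizes = [1, 19, 19, 1]  # made-up example
limit = 20
buckets = cumsum_buckets(sizes, limit)
print(buckets)                        # [0, 1, 1, 2]
print(bucket_totals(sizes, buckets))  # {0: 1, 1: 38, 2: 1}
```

The `groupBy("random_group_id", "sub_group").agg(sum("size"))` check in the answer is a good way to catch this on real data; a greedy fill that opens a new sub-group before the limit is crossed avoids it.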
