Averell

Reputation: 843

Split an array column into chunks of max size

I have a DataFrame with one column of array[string] type.

scala> df.printSchema
root
 |-- user: string (nullable = true) ### this is a unique key
 |-- items: array (nullable = true)
 |    |-- element: string (containsNull = true)

Due to some limitations on the consumer's side, I need to limit the number of elements in the items column, e.g. to a maximum of 1000 elements. The resulting DataFrame would have the same schema, except that user would no longer be a unique key. For example, with max elements = 3:

Input DataFrame:

+----+----------------------+
|user|items                 |
+----+----------------------+
|u1  |[a, b, cc, d, e, f, g]|
|u2  |[h, ii]               |
|u3  |[j, kkkk, m, nn, o]   |
+----+----------------------+

Output DataFrame:

+----+------------+
|user|items       |
+----+------------+
|u1  |[a, f, g]   |
|u1  |[b, cc, d]  |
|u1  |[e]         |
|u2  |[h, ii]     |
|u3  |[j, nn, m]  |
|u3  |[kkkk, o]   |
+----+------------+

The order of items is not important. The value of each item is just a string of alphanumeric chars, but the size of each item is not fixed.

Performance is not an issue since the DataFrame is small, but we need the solution in SparkSQL.

Upvotes: 1

Views: 1197

Answers (1)

mazaneicha

Reputation: 9427

This can be worked out without higher-order functions, in three easy steps:

  1. posexplode the arrays of items, which yields one row per element together with its position pos
  2. take the integer part of pos / N, where N is the desired maximum number of elements per subarray
  3. collect_list the items into new arrays, grouping by user and the value computed in step 2.

For N=3:

    >>> df = spark.createDataFrame([
    ... {'user':'u1','items':['a', 'b', 'cc', 'd', 'e', 'f', 'g']},
    ... {'user':'u2','items':['h', 'ii']},
    ... {'user':'u3','items':['j', 'kkkk', 'm', 'nn', 'o']}
    ... ])
    >>> from pyspark.sql.functions import *
    >>> df1 = df.select(posexplode(df.items),df.user)
    >>> df2 = df1.select(floor(df1.pos/3).alias('pos'),df1.col.alias('item'),df1.user)
    >>> df3 = df2.groupby([df2.user,df2.pos]).agg(collect_list(df2.item)).drop('pos')
    >>> df3.show(truncate=False)
    +----+------------------+
    |user|collect_list(item)|
    +----+------------------+
    |u2  |[h, ii]           |
    |u1  |[a, b, cc]        |
    |u1  |[d, e, f]         |
    |u1  |[g]               |
    |u3  |[nn, o]           |
    |u3  |[j, kkkk, m]      |
    +----+------------------+
    
    >>> 
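
The bucketing arithmetic can also be checked without a Spark session. The following is only a plain-Python sketch of the same three steps, not Spark code; `chunk_items` and its parameters are hypothetical names made up for illustration.

```python
from collections import defaultdict

def chunk_items(rows, n):
    """Split each user's items into lists of at most n elements.

    rows: iterable of (user, items) pairs; n: maximum chunk size.
    Hypothetical helper mirroring posexplode -> floor(pos / n) -> collect_list.
    """
    buckets = defaultdict(list)
    for user, items in rows:
        # Step 1 (posexplode): one (pos, item) pair per array element.
        for pos, item in enumerate(items):
            # Step 2: integer division of pos by n gives the chunk index.
            # Step 3: collect items that share the same (user, chunk index).
            buckets[(user, pos // n)].append(item)
    # Drop the chunk index, as the answer does with .drop('pos').
    return [(user, chunk) for (user, _), chunk in buckets.items()]

rows = [
    ("u1", ["a", "b", "cc", "d", "e", "f", "g"]),
    ("u2", ["h", "ii"]),
    ("u3", ["j", "kkkk", "m", "nn", "o"]),
]
result = chunk_items(rows, 3)
for user, chunk in result:
    print(user, chunk)
```

Positions 0 to 2 map to chunk 0, positions 3 to 5 to chunk 1, and so on, so no chunk can hold more than N items. Like posexplode, this sketch emits nothing for a user whose array is empty.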

Upvotes: 4
