Prince Bhatti

Split Spark dataframe by row index

I want to split a DataFrame in row order. If there are 100 rows, the desired split into 4 equal DataFrames should contain row indices 0-24, 25-49, 50-74, and 75-99, respectively.
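The target ranges depend only on the row count. Below is a minimal pure-Python sketch of that arithmetic. The helper name `chunk_ranges` is hypothetical. When the count does not divide evenly, it assumes the first chunks take the extra rows, which is the convention SQL's `ntile` follows:

```python
def chunk_ranges(num_rows, num_chunks):
    """Inclusive (start, end) row indices for each chunk (hypothetical helper)."""
    base, extra = divmod(num_rows, num_chunks)
    ranges = []
    start = 0
    for c in range(num_chunks):
        # Assumption: the first `extra` chunks each get one additional row.
        size = base + (1 if c < extra else 0)
        ranges.append((start, start + size - 1))
        start += size
    return ranges

print(chunk_ranges(100, 4))  # [(0, 24), (25, 49), (50, 74), (75, 99)]
```

For 10 rows and 4 chunks this gives sizes 3, 3, 2, 2.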

The only built-in function I found is randomSplit, but randomSplit randomizes the data before splitting. Another approach I thought of is to get the row count with the count action and then repeatedly extract rows with take, but that is very expensive. Is there another way to do this while preserving the row order?


Answers (1)

ktdrv

You can use monotonically_increasing_id to get an increasing row ID that follows the current row order (if you don't already have such a column). Then apply ntile over a window ordered by that ID to split the rows into however many parts you want:

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import monotonically_increasing_id, ntile

spark = SparkSession.builder.getOrCreate()

values = [(str(i),) for i in range(100)]
df = spark.createDataFrame(values, ('value',))

def split_by_row_index(df, num_partitions=4):
    # Let's assume you don't have a row_id column that has the row order
    t = df.withColumn('_row_id', monotonically_increasing_id())
    # Using ntile() because monotonically_increasing_id is discontinuous across partitions.
    # Note: a window with no partitionBy moves all rows into a single partition.
    t = t.withColumn('_partition', ntile(num_partitions).over(Window.orderBy(t._row_id)))
    # ntile numbers buckets from 1, so bucket i + 1 becomes the i-th DataFrame
    return [t.filter(t._partition == i + 1).drop('_row_id', '_partition')
            for i in range(num_partitions)]

[i.collect() for i in split_by_row_index(df)]
