Aesir

Reputation: 2473

Does pyspark DataFrame.randomSplit() return a stable split?

I am working on a problem with a smallish dataset. Over time, new data is collected and I would like to add it to my dataset. I have created a unique identifier in my current dataset and have used randomSplit to split it into a train and a test set:

train, test = unique_lifetimes_spark_df.select("lifetime_id").distinct().randomSplit(weights=[0.8, 0.2], seed=42)

If I now update my dataset and re-run the split, will it produce a stable split that only expands the existing groups, or do I run the risk of "polluting" my groups?

As an example, consider the dataset [A, B, C, D, E], which randomSplit splits into two groups:

group 1: [A, B, C, D]

group 2: [E]

At some point I get an updated dataset [A, B, C, D, E, F, G] which I would like to incorporate into my modelling process. By using randomSplit, am I guaranteed that the returned split will never place [A, B, C, D] (which originally appeared in group 1) into group 2, and, vice versa, never place [E] (which originally appeared in group 2) into group 1?

So the updated split should randomly decide where to put [F, G] and the rest should stay in their previously assigned groups.
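
To make the scenario concrete, a minimal reproduction (with hypothetical lifetime_id values standing in for my real ones) would look something like this:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Original set of (hypothetical) lifetime ids
original = spark.createDataFrame([("A",), ("B",), ("C",), ("D",), ("E",)], ["lifetime_id"])

# Updated set after new data has been collected
updated = spark.createDataFrame([("A",), ("B",), ("C",), ("D",), ("E",), ("F",), ("G",)], ["lifetime_id"])

# Same weights and seed for both splits
train_old, test_old = original.randomSplit([0.8, 0.2], seed=42)
train_new, test_new = updated.randomSplit([0.8, 0.2], seed=42)

# Any rows shown here are ids that switched groups between the two runs
train_old.join(test_new, "lifetime_id").show()
test_old.join(train_new, "lifetime_id").show()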

Upvotes: 1

Views: 993

Answers (1)

Oli

Reputation: 10406

No, there is no guarantee that if the original dataset grows, the split will remain the same for the pre-existing elements.

You can test it yourself:

scala> spark.range(5).randomSplit(Array(0.8, 0.2), seed=42).foreach(_.show)
+---+
| id|
+---+
|  1|
|  2|
+---+

+---+
| id|
+---+
|  0|
|  3|
|  4|
+---+


scala> spark.range(6).randomSplit(Array(0.8, 0.2), seed=42).foreach(_.show)
+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
+---+

+---+
| id|
+---+
|  5|
+---+
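
The split is computed by sampling the rows, so the outcome depends on the content and partitioning of the DataFrame, not only on the seed; adding rows can therefore reshuffle the pre-existing ones.

If you need existing lifetime_ids to keep their group as the dataset grows, one common workaround (outside of randomSplit) is to derive the assignment deterministically from the identifier itself, e.g. by hashing it. A rough PySpark sketch, assuming lifetime_id can be cast to a string and hashes roughly uniformly (the 80/20 ratio is then only approximate):

from pyspark.sql import functions as F

# The bucket is a pure function of the id, so an id never changes group
# when new ids are added; the 80/20 proportion is only approximate.
ids = unique_lifetimes_spark_df.select("lifetime_id").distinct()
buckets = ids.withColumn("bucket", F.crc32(F.col("lifetime_id").cast("string")) % 10)

train = buckets.filter(F.col("bucket") < 8).drop("bucket")
test = buckets.filter(F.col("bucket") >= 8).drop("bucket")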

Upvotes: 4
