Reputation: 3929
I found a very strange behavior in pyspark when I use `randomSplit`. I have a column `is_clicked` that takes values 0 or 1, and there are far more zeros than ones. After a random split I would expect the rows to keep roughly the original class distribution. Instead, I see that the first rows in the splits are all `is_clicked=1`, followed by rows that are all `is_clicked=0`. You can see that the number of clicks in the original dataframe `df` is 9 out of 1000 (which is what I expect), but after the random split the number of clicks is 1000 out of 1000. If I take more rows I see that they are all `is_clicked=1` until there are no more such rows, after which every row is `is_clicked=0`.
Does anyone know why the distribution changes after the random split? How can I make `is_clicked` be uniformly distributed after the split?
Upvotes: 1
Views: 1565
So pyspark does indeed sort the data when it performs `randomSplit`. Here is a quote from the code:
> It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its constituent partitions each time a split is materialized which could result in overlapping splits. To prevent this, we explicitly sort each input partition to make the ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out from the sort order.
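That per-partition sort is exactly what produces the pattern in the question: once rows are ordered by their column values, all rows with the same `is_clicked` value end up adjacent. A minimal pure-Python sketch of the effect, using hypothetical toy data (no Spark, sorted descending to mirror the "all 1s first" order seen in the question):

```python
import random

random.seed(0)

# Hypothetical toy data: ~1% clicks, like the dataframe in the question.
rows = [{"is_clicked": 1 if random.random() < 0.01 else 0} for _ in range(1000)]

# randomSplit sorts each input partition to make its ordering deterministic;
# sorting by the row's values clusters identical is_clicked values together.
sorted_rows = sorted(rows, key=lambda r: r["is_clicked"], reverse=True)

n_clicks = sum(r["is_clicked"] for r in rows)
# The first n_clicks rows are now all is_clicked=1, the rest all 0.
print(all(r["is_clicked"] == 1 for r in sorted_rows[:n_clicks]))   # prints True
print(all(r["is_clicked"] == 0 for r in sorted_rows[n_clicks:]))   # prints True
```

So the "skew" is not a change in the data, only in its ordering: taking the head of a split samples from one end of that sorted order.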
The solution is either to reshuffle the data after the split, or to use `filter` instead of `randomSplit`:
Solution 1:

import pyspark.sql.functions as sf

# Add a persistent random column so the ordering is reproducible.
df = df.withColumn('rand', sf.rand(seed=42)).orderBy('rand')
df_train, df_test = df.randomSplit([0.5, 0.5])
# orderBy returns a new dataframe, so the reshuffled result must be assigned.
df_train = df_train.orderBy('rand')
df_test = df_test.orderBy('rand')
Solution 2 (relies on the same `rand` column):

df = df.withColumn('rand', sf.rand(seed=42))
df_train = df.filter(df.rand < 0.5)
df_test = df.filter(df.rand >= 0.5)
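As a sanity check on the filter approach, here is a pure-Python sketch with hypothetical toy data (no Spark): each row gets one fixed uniform number, the threshold assigns it to exactly one split, and the class balance is preserved up to sampling noise.

```python
import random

rng = random.Random(42)

# Hypothetical toy data: ~1% clicks, plus a fixed uniform 'rand' per row
# playing the role of sf.rand(seed=42) in the Spark version.
rows = [
    {"is_clicked": 1 if rng.random() < 0.01 else 0, "rand": rng.random()}
    for _ in range(10_000)
]

# The threshold sends every row to exactly one split: no overlap, no loss.
train = [r for r in rows if r["rand"] < 0.5]
test = [r for r in rows if r["rand"] >= 0.5]
assert len(train) + len(test) == len(rows)

def click_rate(rs):
    return sum(r["is_clicked"] for r in rs) / len(rs)

# Both splits keep roughly the original ~1% click rate.
print(click_rate(rows), click_rate(train), click_rate(test))
```

Because the `rand` value is stored in the dataframe, recomputing either split always yields the same rows, which is the determinism `randomSplit` tries to achieve with its internal sort.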
Here is a blog post with more details.
Upvotes: 2