Reputation: 2947
I have a PySpark DataFrame like the following:
+--------+--------+-----------+
| col1 | col2 | groupId |
+--------+--------+-----------+
| val11 | val21 | 0 |
| val12 | val22 | 1 |
| val13 | val23 | 2 |
| val14 | val24 | 0 |
| val15 | val25 | 1 |
| val16 | val26 | 1 |
+--------+--------+-----------+
Each row has a groupId, and multiple rows can have the same groupId.
I want to randomly split this data into two datasets, but all rows with a particular groupId must end up in the same split.
That is, if d1.groupId = d2.groupId, then d1 and d2 are in the same split.
For example:
# Split 1:
+--------+--------+-----------+
| col1 | col2 | groupId |
+--------+--------+-----------+
| val11 | val21 | 0 |
| val13 | val23 | 2 |
| val14 | val24 | 0 |
+--------+--------+-----------+
# Split 2:
+--------+--------+-----------+
| col1 | col2 | groupId |
+--------+--------+-----------+
| val12 | val22 | 1 |
| val15 | val25 | 1 |
| val16 | val26 | 1 |
+--------+--------+-----------+
What is a good way to do this in PySpark? Can I use the randomSplit method somehow?
Upvotes: 3
Views: 2423
Reputation: 43494
You can use randomSplit to split just the distinct groupIds, and then use the results to split the source DataFrame using join.
For example:
split1, split2 = df.select("groupId").distinct().randomSplit(weights=[0.5, 0.5], seed=0)
split1.show()
#+-------+
#|groupId|
#+-------+
#| 1|
#+-------+
split2.show()
#+-------+
#|groupId|
#+-------+
#| 0|
#| 2|
#+-------+
Now join these back to the original DataFrame:
df1 = df.join(split1, on="groupId", how="inner")
df2 = df.join(split2, on="groupId", how="inner")
df1.show()
#+-------+-----+-----+
#|groupId| col1| col2|
#+-------+-----+-----+
#| 1|val12|val22|
#| 1|val15|val25|
#| 1|val16|val26|
#+-------+-----+-----+
df2.show()
#+-------+-----+-----+
#|groupId| col1| col2|
#+-------+-----+-----+
#| 0|val11|val21|
#| 0|val14|val24|
#| 2|val13|val23|
#+-------+-----+-----+
Upvotes: 4