ywat

Reputation: 2947

Randomly Split DataFrame by Unique Values in One Column

I have a PySpark DataFrame like the following:

+--------+--------+-----------+
| col1   |  col2  |  groupId  |
+--------+--------+-----------+
| val11  | val21  |   0       |
| val12  | val22  |   1       |
| val13  | val23  |   2       |
| val14  | val24  |   0       |
| val15  | val25  |   1       |
| val16  | val26  |   1       |
+--------+--------+-----------+    

Each row has a groupId, and multiple rows can have the same groupId.

I want to randomly split this data into two datasets, but all rows with a given groupId must land in the same split.

That is, if d1.groupId = d2.groupId, then rows d1 and d2 are in the same split.

For example:

# Split 1:

+--------+--------+-----------+
| col1   |  col2  |  groupId  |
+--------+--------+-----------+
| val11  | val21  |   0       |
| val13  | val23  |   2       |
| val14  | val24  |   0       |
+--------+--------+-----------+

# Split 2:
+--------+--------+-----------+
| col1   |  col2  |  groupId  |
+--------+--------+-----------+
| val12  | val22  |   1       |
| val15  | val25  |   1       |
| val16  | val26  |   1       |
+--------+--------+-----------+

What is a good way to do this in PySpark? Can I use the randomSplit method somehow?

Upvotes: 3

Views: 2423

Answers (1)

pault

Reputation: 43494

You can use randomSplit to split just the distinct groupId values, and then use the results to split the source DataFrame with a join.
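
For reference, the sample DataFrame can be built like this (a reproducibility sketch assuming an active SparkSession named spark; not part of the original question):

data = [
    ("val11", "val21", 0),
    ("val12", "val22", 1),
    ("val13", "val23", 2),
    ("val14", "val24", 0),
    ("val15", "val25", 1),
    ("val16", "val26", 1),
]
# Column names match the question's table
df = spark.createDataFrame(data, ["col1", "col2", "groupId"])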

For example:

split1, split2 = df.select("groupId").distinct().randomSplit(weights=[0.5, 0.5], seed=0)
split1.show()
#+-------+
#|groupId|
#+-------+
#|      1|
#+-------+

split2.show()
#+-------+
#|groupId|
#+-------+
#|      0|
#|      2|
#+-------+

Now join these back to the original DataFrame:

df1 = df.join(split1, on="groupId", how="inner")
df2 = df.join(split2, on="groupId", how="inner")
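
Since split1 and split2 hold only distinct groupId values, they are usually much smaller than df, so an explicit broadcast hint can make these joins cheaper. This is an optional optimization, not something the approach requires:

from pyspark.sql.functions import broadcast

# Broadcasting the small lookup tables avoids shuffling the full df
df1 = df.join(broadcast(split1), on="groupId", how="inner")
df2 = df.join(broadcast(split2), on="groupId", how="inner")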

df1.show()
#+-------+-----+-----+
#|groupId| col1| col2|
#+-------+-----+-----+
#|      1|val12|val22|
#|      1|val15|val25|
#|      1|val16|val26|
#+-------+-----+-----+

df2.show()
#+-------+-----+-----+
#|groupId| col1| col2|
#+-------+-----+-----+
#|      0|val11|val21|
#|      0|val14|val24|
#|      2|val13|val23|
#+-------+-----+-----+
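
Note that the weights in randomSplit apply to the number of distinct groupIds, not the number of rows, so the row counts of the two splits need not match the weights (here split 1 contains a single group that happens to hold three rows). As a quick sanity check, you can confirm that no groupId appears in both splits and that no rows were lost:

# No groupId should appear in both splits
assert df1.select("groupId").intersect(df2.select("groupId")).count() == 0

# Every original row should land in exactly one split
assert df1.count() + df2.count() == df.count()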

Upvotes: 4
