Reputation: 3725
I have a fairly large dataset with roughly 5 billion records, and I would like to take about 1 million random samples from it. The issue is that the labels are not balanced:
+---------+----------+
| label| count|
+---------+----------+
|A | 768866802|
|B |4241039902|
|C | 584150833|
+---------+----------+
Label B has a lot more data than the other labels. I know there is the concept of down- and up-sampling, but given the large quantity of data I probably don't need that technique, since I can easily find 1 million records for each of the labels.
I was wondering how I could efficiently take ~1 million random samples (without replacement) so that I have an even amount across all labels, i.e. ~333k per label, using PySpark.
One idea I have is to split the dataset into 3 different DataFrames, take ~300k random samples from each, and stitch them back together. But maybe there is a more efficient way of doing it.
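For reference, a rough sketch of what I mean (assuming the label column is called label; note that sample() takes a fraction, so the per-label counts would only be approximate):
from pyspark.sql import functions as F
n = 333333  # target samples per label
# Split by label, sample each subset, then union them back together
label_counts = df.groupBy("label").count().collect()
samples = []
for row in label_counts:
    frac = min(n / row["count"], 1.0)
    samples.append(df.filter(F.col("label") == row["label"]).sample(False, frac, seed=42))
balanced_df = samples[0]
for s in samples[1:]:
    balanced_df = balanced_df.union(s)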
Upvotes: 1
Views: 1607
Reputation: 3419
You can create a column with random values and use row_number
to filter out n random samples for each label:
from pyspark.sql.types import *
from pyspark.sql import functions as F
from pyspark.sql.functions import *
from pyspark.sql.window import Window
n = 333333  # number of samples per label
# Assign a random value to every row, then keep the first n rows per label
df = df.withColumn('rand_col', F.rand())
sample_df1 = df.withColumn("row_num", row_number().over(Window.partitionBy("label")\
                .orderBy("rand_col"))).filter(col("row_num") <= n)\
                .drop("rand_col", "row_num")
sample_df1.groupBy("label").count().show()
This will always give you exactly n samples for each label.
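If you want to derive n from an overall target of ~1M rather than hard-coding it, one way (just a sketch) is to divide the target by the number of distinct labels:
total_target = 1000000
num_labels = df.select("label").distinct().count()  # 3 labels in your case
n = total_target // num_labels  # ~333333 per label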
Another way of doing this is stratified sampling using Spark's stat.sampleBy:
n = 333333
seed = 12345
# Creating a dictionary of fractions for each label (required_n / count)
fractions = df.groupBy("label").count().withColumn("required_n", n/col("count"))\
              .drop("count").rdd.collectAsMap()
sample_df2 = df.stat.sampleBy("label", fractions, seed)
sample_df2.groupBy("label").count().show()
sampleBy, however, gives only an approximate result that varies between runs and does not guarantee an exact number of records for each label.
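For illustration (using the counts from your question, not part of the code above), the fractions dictionary would look roughly like this:
n = 333333
label_counts = {"A": 768866802, "B": 4241039902, "C": 584150833}
fractions = {label: n / cnt for label, cnt in label_counts.items()}
# {'A': ~4.3e-4, 'B': ~7.9e-5, 'C': ~5.7e-4}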
Example dataframe:
schema = StructType([StructField("id", IntegerType()), StructField("label", IntegerType())])
data = [[1, 2], [1, 2], [1, 3], [2, 3], [1, 2], [1, 1], [1, 2], [1, 3], [2, 2], [1, 1], [1, 2], [1, 2], [1, 3], [2, 3], [1, 1]]
df = spark.createDataFrame(data, schema=schema)
df.groupBy("label").count().show()
+-----+-----+
|label|count|
+-----+-----+
| 1| 3|
| 2| 7|
| 3| 5|
+-----+-----+
Method 1:
# Sampling 3 records from each label
n = 3
# Assign a column with random values
df = df.withColumn('rand_col', F.rand())
sample_df1 = df.withColumn("row_num", row_number().over(Window.partitionBy("label")\
                .orderBy("rand_col"))).filter(col("row_num") <= n)\
                .drop("rand_col", "row_num")
sample_df1.groupBy("label").count().show()
+-----+-----+
|label|count|
+-----+-----+
| 1| 3|
| 2| 3|
| 3| 3|
+-----+-----+
Method 2:
# Sampling 3 records from each label
n = 3
seed = 12
fractions = df.groupBy("label").count().withColumn("required_n", n/col("count"))\
              .drop("count").rdd.collectAsMap()
sample_df2 = df.stat.sampleBy("label", fractions, seed)
sample_df2.groupBy("label").count().show()
+-----+-----+
|label|count|
+-----+-----+
| 1| 3|
| 2| 3|
| 3| 4|
+-----+-----+
As you can see, sampleBy tends to give you an approximately equal distribution, but not an exact one. I'd prefer Method 1 for your problem.
Upvotes: 2