Chris

Reputation: 3725

Pyspark: Take n samples with balanced classes

I have a fairly large dataset with roughly 5 billion records. I would like to take 1 million random samples out of it. The issue is that the labels are not balanced:

+---------+----------+
|    label|     count|
+---------+----------+
|A        | 768866802|
|B        |4241039902|
|C        | 584150833|
+---------+----------+

Label B has a lot more data than the other labels. I know there is a concept of down- and up-sampling, but given the large quantity of data I probably don't need those techniques, since I can easily find 1 million records for each of the labels.

I was wondering how I could efficiently take ~1 million random samples (without replacement) in PySpark so that I end up with an even amount across all labels, i.e. ~333k from each label.

One idea I have is to split the dataset into 3 different DataFrames, take ~333k random samples from each, and stitch them back together, roughly as in the sketch below. But maybe there are more efficient ways of doing it.
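
For reference, a rough sketch of that idea (assuming the label column is called "label" and a target of ~333k rows per label; .sample() works with fractions, so the per-label counts would only be approximate):

from functools import reduce
from pyspark.sql import functions as F

n = 333333
seed = 42

# Per-label counts, used to turn the target n into a sampling fraction
counts = {row["label"]: row["count"] for row in df.groupBy("label").count().collect()}

# Sample each label separately, then union the pieces back together
samples = [
    df.filter(F.col("label") == lbl).sample(False, min(1.0, n / cnt), seed)
    for lbl, cnt in counts.items()
]
balanced_df = reduce(lambda a, b: a.unionByName(b), samples)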

Upvotes: 1

Views: 1607

Answers (1)

Cena

Reputation: 3419

You can create a column with random values and use row_number to keep exactly n random rows for each label:

from pyspark.sql.types import *
from pyspark.sql import functions as F
from pyspark.sql.functions import *
from pyspark.sql.window import Window

n = 333333  # number of samples per label
df = df.withColumn('rand_col', F.rand())
sample_df1 = df.withColumn("row_num", row_number().over(Window.partitionBy("label")\
                .orderBy("rand_col"))).filter(col("row_num") <= n)\
                .drop("rand_col", "row_num")

sample_df1.groupBy("label").count().show()

This will always give you exactly n (= 333,333) samples for each label, i.e. ~1M rows in total.

Another way of doing this is stratified sampling using Spark's stat.sampleBy:

n = 333333
seed = 12345

# Creating a dictionary of fractions for each label
fractions = df.groupBy("label").count().withColumn("required_n", n/col("count"))\
                .drop("count").rdd.collectAsMap()

sample_df2 = df.stat.sampleBy("label", fractions, seed)

sample_df2.groupBy("label").count().show()

sampleBy, however, only yields an approximate result: the number of records returned for each label varies from run to run and is not guaranteed to be exactly n.
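
If an exact per-label count is needed with this approach, one possible workaround (just a sketch, reusing the same df and assuming an arbitrary ~10% oversampling margin) is to oversample slightly with sampleBy and then trim each label to exactly n rows with the row_number trick from the first method:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

n = 333333
seed = 12345

# Oversample by ~10% so each label very likely ends up with at least n rows;
# fractions passed to sampleBy must stay within [0, 1], hence the cap
fractions = {row["label"]: min(1.0, 1.1 * n / row["count"])
             for row in df.groupBy("label").count().collect()}

oversampled = df.stat.sampleBy("label", fractions, seed)

# Trim each label down to exactly n rows (sample_df3 is just an illustrative name)
w = Window.partitionBy("label").orderBy(F.rand(seed))
sample_df3 = oversampled.withColumn("row_num", F.row_number().over(w))\
                .filter(F.col("row_num") <= n)\
                .drop("row_num")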

Example dataframe:

schema = StructType([StructField("id", IntegerType()), StructField("label", IntegerType())])
data = [[1, 2], [1, 2], [1, 3], [2, 3], [1, 2],[1, 1], [1, 2], [1, 3], [2, 2], [1, 1],[1, 2], [1, 2], [1, 3], [2, 3], [1, 1]]
df = spark.createDataFrame(data,schema=schema)

df.groupBy("label").count().show()
+-----+-----+
|label|count|
+-----+-----+
|    1|    3|
|    2|    7|
|    3|    5|
+-----+-----+

Method 1:

# Sampling 3 records from each label
n = 3 

# Assign a column with random values
df = df.withColumn('rand_col', F.rand())
sample_df1 = df.withColumn("row_num", row_number().over(Window.partitionBy("label")\
                .orderBy("rand_col"))).filter(col("row_num") <= n)\
                .drop("rand_col", "row_num")

sample_df1.groupBy("label").count().show()
+-----+-----+
|label|count|
+-----+-----+
|    1|    3|
|    2|    3|
|    3|    3|
+-----+-----+

Method 2:

# Sampling 3 records from each label
n = 3
seed = 12

fractions = df.groupBy("label").count().withColumn("required_n", n/col("count"))\
                .drop("count").rdd.collectAsMap()

sample_df2 = df.stat.sampleBy("label", fractions, seed)
sample_df2.groupBy("label").count().show()

+-----+-----+
|label|count|
+-----+-----+
|    1|    3|
|    2|    3|
|    3|    4|
+-----+-----+

As you can see, sampleBy gives you approximately the requested number of records per label, but not exactly. I'd prefer Method 1 for your problem.

Upvotes: 2
