fractalnature

Reputation: 145

Sample a different number of random rows for every group in a DataFrame in Spark Scala

The goal is to sample (without replacement) a different number of rows in a DataFrame for every group. The number of rows to sample for a specific group is in another DataFrame.

Example: idDF is the DataFrame to sample from. The groups are denoted by the ID column. The DataFrame planDF specifies the number of rows to sample for each group, where "datesToUse" denotes the number of rows and "ID" denotes the group. "totalDates" is the total number of rows for that group and may or may not be useful.

The final result should have 3 rows sampled from the first group (ID 1), 2 rows sampled from the second group (ID 2) and 1 row sampled from the third group (ID 3).

val idDF = Seq(
  (1, "2017-10-03"),
  (1, "2017-10-22"),
  (1, "2017-11-01"),
  (1, "2017-10-02"),
  (1, "2017-10-09"),
  (1, "2017-12-24"),
  (1, "2017-10-20"),
  (2, "2017-11-17"),
  (2, "2017-11-12"),
  (2, "2017-12-02"),      
  (2, "2017-10-03"),
  (3, "2017-12-18"),
  (3, "2017-11-21"),
  (3, "2017-12-13"),
  (3, "2017-10-08"),
  (3, "2017-10-16"),
  (3, "2017-12-04")
 ).toDF("ID", "date")

val planDF = Seq(
  (1, 3, 7),
  (2, 2, 4),
  (3, 1, 6)
 ).toDF("ID", "datesToUse", "totalDates")

This is an example of what the resulting DataFrame should look like:

+---+----------+
| ID|      date|
+---+----------+
|  1|2017-10-22|
|  1|2017-11-01|
|  1|2017-10-20|
|  2|2017-11-12|
|  2|2017-10-03|
|  3|2017-10-16|
+---+----------+

So far, I have tried to use the sample method on DataFrame (https://spark.apache.org/docs/1.5.0/api/java/org/apache/spark/sql/DataFrame.html). Here is an example that would work for an entire DataFrame:

def sampleDF(DF: DataFrame, datesToUse: Int, totalDates: Int): DataFrame = {
  // Convert to Double before dividing to avoid integer division
  val fraction = datesToUse.toDouble / totalDates
  DF.sample(false, fraction)
}

I can't figure out how to use something like this for each group. I tried joining the planDF table to the idDF table and using a window partition.
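
Roughly, that window idea would look something like this (an untested sketch; it assumes spark.implicits._ is in scope and ranks the rows of each group in random order so the top datesToUse rows can be kept):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rand, row_number}

// Rank the rows within each ID in random order
val byIdRandom = Window.partitionBy("ID").orderBy(rand())

val sampled = idDF
  .join(planDF.select("ID", "datesToUse"), Seq("ID"))
  .withColumn("rn", row_number().over(byIdRandom))
  .where($"rn" <= $"datesToUse")   // keep exactly datesToUse rows per group
  .drop("rn", "datesToUse")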

Another idea I had was to somehow make a new column randomly labeled true/false, and then filter on that column.

Upvotes: 3

Views: 5293

Answers (2)

hoyland

Reputation: 1824

Another option, staying entirely in DataFrames, would be to compute probabilities using your planDF, join with idDF, append a column of random numbers, and then filter. Helpfully, sql.functions has a rand function.

import org.apache.spark.sql.functions._
import spark.implicits._

val probabilities = planDF.withColumn("prob", $"datesToUse" / $"totalDates")

val dfWithProbs = idDF.join(probabilities, Seq("ID"))
  .withColumn("rand", rand())
  .where($"rand" < $"prob")

(You'll want to double-check that the division isn't integer division.)
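
If in doubt, an explicit cast makes the fractional division unambiguous (a variant of the line above, not part of the original answer):

// Cast one operand to double so the result of "/" is clearly fractional
val probabilities = planDF.withColumn("prob", $"datesToUse".cast("double") / $"totalDates")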

Upvotes: 4

Leo C

Reputation: 22439

With the assumption that your planDF is small enough to be collected, you can use Scala's foldLeft to traverse the ID list and accumulate the sampled DataFrames per ID:

import org.apache.spark.sql.{Row, DataFrame}

def sampleByIdDF(DF: DataFrame, id: Int, datesToUse: Int, totalDates: Int): DataFrame = {
  val fraction = datesToUse.toDouble / totalDates
  DF.where($"id" === id ).sample(false, fraction)
}

val emptyDF = Seq.empty[(Int, String)].toDF("ID", "date")

val planList = planDF.rdd.collect.map{ case Row(x: Int, y: Int, z: Int) => (x, y, z) }
// planList: Array[(Int, Int, Int)] = Array((1,3,7), (2,2,4), (3,1,6))

planList.foldLeft( emptyDF ){
  case (accDF: DataFrame, (id: Int, num: Int, total: Int)) =>
    accDF union sampleByIdDF(idDF, id, num, total)
}
// res1: org.apache.spark.sql.DataFrame = [ID: int, date: string]

// res1.show
// +---+----------+
// | ID|      date|
// +---+----------+
// |  1|2017-10-03|
// |  1|2017-11-01|
// |  1|2017-10-02|
// |  1|2017-12-24|
// |  1|2017-10-20|
// |  2|2017-11-17|
// |  2|2017-11-12|
// |  2|2017-12-02|
// |  3|2017-11-21|
// |  3|2017-12-13|
// +---+----------+

Note that the sample() method does not necessarily return the exact number of rows implied by the fraction argument. Here's a relevant SO Q&A.

If your planDF is large, you might have to consider using RDD's aggregate, which has the following signature (skipping the implicit argument):

def aggregate[U](zeroValue: U)(seqOp: (U, T) ⇒ U, combOp: (U, U) ⇒ U): U

It works somewhat like foldLeft, except that it has one accumulation operator within a partition and an additional one to combine results from different partitions.
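
For illustration only (a toy example, not part of the sampling solution), here is how the two operators fit together when summing and counting the elements of an RDD in a single pass:

val nums = spark.sparkContext.parallelize(1 to 100)

val (sum, count) = nums.aggregate((0, 0))(
  (acc, x) => (acc._1 + x, acc._2 + 1),  // seqOp: fold values within a partition
  (a, b) => (a._1 + b._1, a._2 + b._2)   // combOp: merge per-partition results
)
// sum: Int = 5050, count: Int = 100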

Upvotes: 1
