Reputation: 29
I have a DataFrame df that has, among others, a column groupID; that is, each observation belongs to a specific group. In total there are 8 groups. I would like to sample a certain percentage of observations (say, 20%) from each groupID. Here is my approach:
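val sample_df = for (i <- Array.range(0, 8)) yield { // end is exclusive: 0 to 7 covers all 8 groups
  val sel_df = df.filter($"groupID" === i)           // rows of group i only
  sel_df.sample(false, 0.2, seed1)                    // ~20% of that group, without replacement
}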
The result of this code is:
Array[org.apache.spark.sql.DataFrame] = Array([text: string, groupID: int], [text: string, groupID: int])
I applied flatMap() on sample_df, but I got an error:
val flat_df = sample_df.flatMap(x => x)
<console>:59: error: type mismatch;
found: org.apache.spark.sql.DataFrame
required: scala.collection.GenTraversableOnce[?]
How can I get a sampled dataframe?
Upvotes: 1
Views: 4780
Reputation: 331
It seems to me you just want to take a 20% sample of the entire dataframe? If so, then there is no reason to create 8 different dataframes and then union them back.
df.sample(false, 0.2, seed)
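will do the trick, although the fraction is only an expected value: each row is kept independently with probability 0.2, so every group ends up with roughly, not exactly, 20% of its rows. If you want different fractions for each groupID, check out df.stat.sampleBy, which is approximate in the same way. If you want to be sure that the sample contains exactly 20% of each class, you'll have to convert to a pair RDD and use exact stratified sampling, like: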
val fractions = (0 until 8).map(id => id -> 0.2).toMap // one entry per groupID, assuming IDs 0 to 7
df.rdd.map(row => (row.getAs[Int]("groupID"), row)).sampleByKeyExact(false, fractions, seed)
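A minimal sketch of both options side by side, assuming Spark 2.x or later, an integer groupID column, and group IDs 0 to 7 (the object and method names are hypothetical). sampleBy keeps each row independently, so per-group counts are approximate; sampleByKeyExact makes extra passes over the data to return ceil(fraction × group size) rows per group, and it needs a fraction for every key that actually occurs. The sampled rows are turned back into a DataFrame with the original schema.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object StratifiedSampling {
  // Hypothetical: assumes groupIDs are exactly 0..7 (8 groups), each sampled at 20%
  val fractions: Map[Int, Double] = (0 until 8).map(id => id -> 0.2).toMap

  // Approximate: each row of group k is kept with probability fractions(k)
  def approxSample(df: DataFrame, seed: Long): DataFrame =
    df.stat.sampleBy("groupID", fractions, seed)

  // Exact per-group sizes via the pair-RDD API, then rebuilt as a DataFrame
  def exactSample(df: DataFrame, seed: Long): DataFrame = {
    val sampledRows = df.rdd
      .map(row => (row.getAs[Int]("groupID"), row)) // key each Row by its group
      .sampleByKeyExact(withReplacement = false, fractions, seed)
      .values                                        // drop the key, keep the Row
    df.sparkSession.createDataFrame(sampledRows, df.schema)
  }
}
```

The exact variant costs additional jobs over the data, so it is only worth it when the per-group counts really must be exact.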
Upvotes: 0
Reputation: 4925
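I guess you want to sample evenly from each group. In that case, union the per-group samples back into a single DataFrame:
sample_df.reduceLeft((result, next) => result.union(next)) // union is Spark 2.0+; on Spark 1.x use unionAll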
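If the group IDs are not known in advance, a sketch like the following (hypothetical helper name, assuming Spark 2.x or later and an integer groupID column) looks them up first instead of hardcoding the range. Because every group is sampled with the same fraction, the result has the same distribution as one df.sample call on the whole DataFrame; the per-group form mainly matters if you later want different fractions per group.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

object PerGroupSample {
  // Hypothetical helper: samples roughly `fraction` of every distinct groupID and unions the pieces.
  // Launches one filter + sample per group, so it suits a small number of groups.
  def samplePerGroup(df: DataFrame, fraction: Double, seed: Long): DataFrame = {
    // Distinct group IDs, collected to the driver (assumes groupID is an int column)
    val groupIds: Array[Int] = df.select("groupID").distinct().collect().map(_.getInt(0))
    groupIds
      .map(id => df.filter(col("groupID") === id).sample(withReplacement = false, fraction, seed))
      .reduce(_ union _) // throws on an empty DataFrame, since there is nothing to reduce
  }
}
```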
Upvotes: 1
Reputation: 252
As far as I understand, you are trying to get an RDD of Row. For that you can simply call:
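import org.apache.spark.rdd.RDD, org.apache.spark.sql.Row
val rows: RDD[Row] = sample_df.map(_.rdd).reduce(_ union _) // sample_df is an Array[DataFrame], so combine the per-group RDDs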
To explain the error better: flatMap requires a function that returns something traversable, like an Option or a collection, but x => x returns a DataFrame, which is not a Scala collection.
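A small plain-Scala illustration of what flatMap accepts (no Spark needed; the object name is hypothetical): the function's result must itself be something flatMap can iterate, such as a collection or an Option. A DataFrame is neither, which is exactly the "required: scala.collection.GenTraversableOnce[?]" mismatch in the question.

```scala
import scala.util.Try

object FlatMapDemo {
  // Each element is already a collection, so the identity function is a valid flatMap argument
  val nested: List[List[Int]] = List(List(1, 2), List(3))
  val flat: List[Int] = nested.flatMap(x => x)

  // Option also works: None contributes nothing, Some(v) contributes v
  val parsed: List[Int] = List("1", "x", "3").flatMap(s => Try(s.toInt).toOption)

  def main(args: Array[String]): Unit = {
    println(flat.mkString(","))   // 1,2,3
    println(parsed.mkString(",")) // 1,3
  }
}
```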
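Also, to bring all the data to the driver, you can call:
import org.apache.spark.sql.Row
val rows: Array[Row] = sample_df.flatMap(_.collect()) // collects every per-group sample; only safe if it fits in driver memory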
Upvotes: 2