aigujin

Reputation: 29

For loop Spark dataframe

I have a DataFrame df that has, among other columns, a groupID column; that is, each observation belongs to a specific group. In total there are 8 groups. I would like to sample a certain percentage of observations (say, 20%) from each groupID. Here is my approach:

val sample_df = for (i <- Array.range(0, 8)) yield {  // groupIDs 0 to 7; the range end is exclusive
  val sel_df = df.filter($"groupID" === i)
  sel_df.sample(false, 0.2, seed1)
}

The result of this code is:

Array[org.apache.spark.sql.DataFrame] = Array([text: string, groupID: int], [text: string, groupID: int])

I applied flatMap() to sample_df, but got an error:

val flat_df = sample_df.flatMap(x => x)
         <console>:59: error: type mismatch;
         found: org.apache.spark.sql.DataFrame
         required: scala.collection.GenTraversableOnce[?]

How can I get a single sampled DataFrame?

Upvotes: 1

Views: 4780

Answers (3)

Seth Hendrickson

Reputation: 331

It seems to me you just want to take a 20% sample of the entire DataFrame? If so, there is no reason to create 8 different DataFrames and then union them back: sampling the whole DataFrame keeps every row, and therefore every group, at a rate of about 20%, so

df.sample(false, 0.2, seed)

will do the trick. If you want different fractions for each groupID, check out df.stat.sampleBy. If you need exactly 20% of each group in the sample (up to rounding), you'll have to convert to a pair RDD and use stratified sampling with sampleByKeyExact, like:

df.rdd.map(row => (row.getAs[Int](groupIDIndex), row)).sampleByKeyExact(false, (0 until 8).map(_ -> 0.2).toMap, seed)
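A self-contained sketch contrasting the two stratified options mentioned above: df.stat.sampleBy (approximate per-group fractions) and sampleByKeyExact (exact per-group sizes), with the exact result turned back into a DataFrame. It assumes Spark 2.0+ with a local SparkSession, hypothetical toy data with groupID values 0 to 7, and an arbitrary seed of 42; the object and method names are illustrative only. Keying by getAs[Int] rather than row(i) matters: row(i) returns Any, and a Map[Int, Double] of fractions would then not match the key type.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object StratifiedSamplingExample {
  // Hypothetical toy data standing in for the question's df:
  // 800 rows, groupID cycling through 0..7 (100 rows per group).
  def toyData(spark: SparkSession): DataFrame = {
    import spark.implicits._
    (0 until 800).map(i => (s"text$i", i % 8)).toDF("text", "groupID")
  }

  // One fraction per groupID value (0..7 assumed); the values could differ per group.
  val fractions: Map[Int, Double] = (0 until 8).map(id => id -> 0.2).toMap

  // Approximate: each row of group g is kept independently with probability
  // fractions(g), so a group of 100 rows yields about 20 rows, not exactly 20.
  def approxSample(df: DataFrame): DataFrame =
    df.stat.sampleBy("groupID", fractions, 42L)

  // Exact: key each Row by its groupID as an Int so the key type matches
  // Map[Int, Double]; sampleByKeyExact makes extra passes over the data to
  // target math.ceil(fraction * groupSize) rows per key.
  def exactSample(spark: SparkSession, df: DataFrame): DataFrame = {
    val keyed: RDD[(Int, Row)] = df.rdd.map(row => (row.getAs[Int]("groupID"), row))
    val sampledRows: RDD[Row] =
      keyed.sampleByKeyExact(withReplacement = false, fractions = fractions, seed = 42L).values
    // Rebuild a DataFrame with the original schema.
    spark.createDataFrame(sampledRows, df.schema)
  }

  def main(args: Array[String]): Unit = {
    // Local session for illustration; in spark-shell, reuse the existing `spark`.
    val spark = SparkSession.builder().master("local[*]").appName("stratified").getOrCreate()
    val df = toyData(spark)
    approxSample(df).groupBy("groupID").count().orderBy("groupID").show()
    exactSample(spark, df).groupBy("groupID").count().orderBy("groupID").show()
    spark.stop()
  }
}
```

The exact variant trades extra passes over the data for predictable group sizes; for a one-off 20% sample, sampleBy (or plain df.sample) is usually enough.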

Upvotes: 0

Rockie Yang

Reputation: 4925

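I guess you want to sample evenly from each group. In that case, combine the per-group samples into one DataFrame:

sample_df.reduceLeft((result, df) => result.unionAll(df))
// In Spark 2.0+ you can also write: sample_df.reduce(_ union _)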
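A self-contained sketch of this union approach, assuming Spark 2.0+ (where union covers what unionAll did in 1.x), a local SparkSession, hypothetical toy data with groupID values 0 to 7, and an arbitrary seed of 42. The object name is illustrative only.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object UnionOfSamplesExample {
  def run(spark: SparkSession): DataFrame = {
    import spark.implicits._
    // Hypothetical toy data: 800 rows, groupID 0..7, 100 rows per group.
    val df = (0 until 800).map(i => (s"text$i", i % 8)).toDF("text", "groupID")
    val seed = 42L

    // One sampled DataFrame per group, mirroring the question's for/yield loop.
    val perGroup: Seq[DataFrame] = (0 until 8).map { i =>
      df.filter($"groupID" === i).sample(withReplacement = false, fraction = 0.2, seed = seed)
    }

    // Fold the per-group samples into one DataFrame. union matches columns by
    // position and keeps duplicates (UNION ALL semantics); reduce is safe
    // here because perGroup is non-empty.
    perGroup.reduce(_ union _)
  }

  def main(args: Array[String]): Unit = {
    // Local session for illustration; in spark-shell, reuse the existing `spark`.
    val spark = SparkSession.builder().master("local[*]").appName("unionSamples").getOrCreate()
    run(spark).groupBy("groupID").count().orderBy("groupID").show()
    spark.stop()
  }
}
```

Note that this scans df once per group (8 filters); sampling the whole DataFrame in one pass gives the same expected per-group rate.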

Upvotes: 1

Furkan Varol

Reputation: 252

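As far as I understand, you are trying to get an RDD of Row. Since sample_df is an Array[DataFrame], convert each element with .rdd and union the results:

import org.apache.spark.rdd.RDD, org.apache.spark.sql.Row
val rows: RDD[Row] = sample_df.map(_.rdd).reduce(_ union _)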

To explain the error better: flatMap requires a function that returns something traversable, such as an Option or an Array, but x => x returns a DataFrame, which is not a Scala collection.
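The type rule behind this error can be seen without Spark. In this minimal plain-Scala sketch, the hypothetical Box class stands in for a DataFrame, i.e. a value that is not itself a collection:

```scala
object FlatMapTypeDemo {
  // Hypothetical stand-in for DataFrame: a wrapper that is NOT a Scala collection.
  final case class Box(values: List[Int])

  val nested: Array[List[Int]] = Array(List(1, 2), List(3))
  // Compiles: each element is already a collection, so x => x is a valid argument.
  val flat: Array[Int] = nested.flatMap(x => x)

  val boxes: Array[Box] = Array(Box(List(1, 2)), Box(List(3)))
  // boxes.flatMap(x => x)  // does not compile: Box is not a collection,
  //                        // the same kind of type mismatch as in the question
  // Compiles: the function extracts a collection from each element.
  val unboxed: Array[Int] = boxes.flatMap(_.values)

  def main(args: Array[String]): Unit = {
    println(flat.mkString(","))    // prints 1,2,3
    println(unboxed.mkString(",")) // prints 1,2,3
  }
}
```

With Spark, the analogous fix is a function that returns a collection for each DataFrame, for example one that calls collect() on it.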

Also, to bring all the sampled rows to the driver (only advisable when they fit in driver memory), you can call:

val rows: Array[Row] = sample_df.flatMap(_.collect())

Upvotes: 2
