Cody Berry

Reputation: 263

Is there a better way to go about trimming my Spark DataFrame in this way?

In the following example, for each query I want to keep only the x Ids with the highest counts, where x is determined by a variable called howMany.

For example, given this DataFrame:

+------+--+-----+
|query |Id|count|
+------+--+-----+
|query1|11|2    |
|query1|12|1    |
|query2|13|2    |
|query2|14|1    |
|query3|13|2    |
|query4|12|1    |
|query4|11|1    |
|query5|12|1    |
|query5|11|2    |
|query5|14|1    |
|query5|13|3    |
|query6|15|2    |
|query6|16|1    |
|query7|17|1    |
|query8|18|2    |
|query8|13|3    |
|query8|12|1    |
+------+--+-----+
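To try this locally, the sample input above can be recreated as a DataFrame. This is a sketch assuming Spark 3.x with a local `SparkSession`; the object name `SampleData` and the `local[*]` master are illustrative choices, not part of the question.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object SampleData {
  // Local session for experimentation (assumed setup, not from the question)
  lazy val spark: SparkSession = SparkSession.builder()
    .master("local[*]")
    .appName("top-ids-per-query")
    .getOrCreate()

  // Same rows as the table above: (query, Id, count)
  def df: DataFrame = {
    import spark.implicits._
    Seq(
      ("query1", 11, 2), ("query1", 12, 1),
      ("query2", 13, 2), ("query2", 14, 1),
      ("query3", 13, 2),
      ("query4", 12, 1), ("query4", 11, 1),
      ("query5", 12, 1), ("query5", 11, 2), ("query5", 14, 1), ("query5", 13, 3),
      ("query6", 15, 2), ("query6", 16, 1),
      ("query7", 17, 1),
      ("query8", 18, 2), ("query8", 13, 3), ("query8", 12, 1)
    ).toDF("query", "Id", "count")
  }
}
```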

I would like to get the following DataFrame if howMany is 2.

+------+-------+-----+
|query |Ids    |count|
+------+-------+-----+
|query1|[11,12]|2    |
|query2|[13,14]|2    |
|query3|[13]   |2    |
|query4|[12,11]|1    |
|query5|[11,13]|2    |
|query6|[15,16]|2    |
|query7|[17]   |1    |
|query8|[18,13]|2    |
+------+-------+-----+

I then want to remove the count column, but that is trivial.
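Removing the column is a single `drop` call, which returns a new DataFrame and leaves the original untouched. A minimal sketch, assuming a local `SparkSession`; `DropCountExample` is a hypothetical name and the two rows are copied from the desired output above.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object DropCountExample {
  // Local session, assumed for illustration only
  val spark: SparkSession = SparkSession.builder()
    .master("local[*]").appName("drop-count").getOrCreate()
  import spark.implicits._

  // Two rows shaped like the desired output: (query, Ids, count)
  val withCount: DataFrame =
    Seq(("query1", Seq(11, 12), 2), ("query7", Seq(17), 1)).toDF("query", "Ids", "count")

  // drop returns a new DataFrame without the named column
  val withoutCount: DataFrame = withCount.drop("count")
}
```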

I have a way to do this, but I think it defeats the purpose of Scala altogether and wastes a lot of runtime. Being new to this, I am unsure of the best approach.

My current method is as follows. First, I get a distinct list of the values in the query column and create an iterator over it. Second, I loop through the list with the iterator and, on each pass, trim the DataFrame to the current query using df.select($"eachColumnName"...).where($"query".equalTo(iter.next())). I then sort by count in descending order, apply .limit(howMany), and run groupBy($"query").agg(collect_list($"Id").as("Ids")). Lastly, I union each of these results onto an initially empty DataFrame and return the combined DataFrame.

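import org.apache.spark.sql.functions._
import spark.implicits._ // for $"..." and .as[String]; assumes a SparkSession named spark
// Distinct query values, collected to the driver
val queries = df.select($"query").distinct().as[String].collect().toList
// Assumes emptyDF was created beforehand with schema (query: string, Ids: array<int>)
var resultDF = emptyDF
val iter = queries.iterator
while (iter.hasNext) {
    val middleDF = df.select($"query", $"Id", $"count").where($"query" === iter.next())
    // Keep the howMany highest counts for this query, then collect their Ids
    val queryDF = middleDF.sort(col("count").desc).limit(howMany).groupBy(col("query")).agg(collect_list("Id").as("Ids"))
    // union returns a new DataFrame, so the result must be reassigned
    resultDF = resultDF.union(queryDF)
}
resultDF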
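If the per-query loop is kept, the mutable accumulator and the pre-built empty DataFrame can be avoided by mapping each query to its own small DataFrame and folding the list with `union`. This is a sketch of the same strategy, not a faster one: it still runs a separate filter, sort, and aggregation per distinct query, so the runtime concern remains. `LoopVersion.topIds` is a hypothetical name.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

object LoopVersion {
  // Same per-query strategy as the question, without an empty starting DataFrame.
  // Assumes df has columns query (string), Id and count, and at least one row.
  def topIds(df: DataFrame, howMany: Int): DataFrame = {
    // Distinct query values, collected to the driver
    val queries = df.select(col("query")).distinct().collect().map(_.getString(0)).toList
    queries
      .map { q =>
        df.where(col("query") === q)
          .sort(col("count").desc)
          .limit(howMany)
          .groupBy(col("query"))
          .agg(collect_list(col("Id")).as("Ids"))
      }
      .reduce(_ union _) // reduce throws on an empty list, hence the non-empty assumption
  }
}
```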

Upvotes: 0

Views: 46

Answers (1)

Raphael Roth

Reputation: 27383

I would do this using a window function to get each row's rank within its query, then groupBy to aggregate:

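import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // for $"..." (already imported in spark-shell)

val howMany = 2

val newDF = df
  // number the rows within each query, highest count first
  .withColumn("rank", row_number().over(Window.partitionBy($"query").orderBy($"count".desc)))
  // keep only the top howMany rows per query
  .where($"rank" <= howMany)
  .groupBy($"query")
  .agg(
    collect_list($"Id").as("Ids"),
    max($"count").as("count")
  )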
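For reference, the same window-function approach can be wrapped in a reusable function that keeps only the `Ids` column, since the question drops `count` anyway. A sketch assuming Spark 3.x; `WindowVersion.topIds` is a hypothetical name.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object WindowVersion {
  // Assumes df has columns query, Id and count, as in the question.
  def topIds(df: DataFrame, howMany: Int): DataFrame = {
    // Number the rows within each query, highest count first.
    // row_number breaks ties arbitrarily, so equal counts at the cutoff may vary between runs.
    val byCount = Window.partitionBy(col("query")).orderBy(col("count").desc)
    df.withColumn("rank", row_number().over(byCount))
      .where(col("rank") <= howMany)
      .groupBy(col("query"))
      // collect_list does not guarantee the order of the collected Ids
      .agg(collect_list(col("Id")).as("Ids"))
  }
}
```

On the sample data with howMany = 2, this should yield the same set of Ids per query as the expected output in the question, although the order inside each list is not guaranteed. Unlike the loop, everything runs as a single query plan that ranks all queries in parallel instead of launching separate jobs per distinct query. If tied counts at the cutoff should all be kept, `rank()` could replace `row_number()`, at the cost of sometimes returning more than howMany Ids.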

Upvotes: 1
