I have the DataFrame below:
|year|month|item|quantity|
|2019|1    |TV  |8       |
|2019|2    |AC  |10      |
|2018|1    |TV  |2       |
|2018|2    |AC  |3       |
I wanted to get the expected output below by using a window function:
val partitionWindow = Window.partitionBy("year").orderBy("month")
val itemsList = collect_list(struct("item", "quantity")).over(partitionWindow)
// df is the DataFrame shown above (the name is assumed)
df.select(col("year"), itemsList as "items")
Expected output:
|year|items |
|2019|[[TV, 8], [AC, 10]]|
|2018|[[TV, 2], [AC, 3]] |
But when I use the window function, I get a duplicate row for each item:
Current output:
|year|items |
|2019|[[TV, 8]] |
|2019|[[TV, 8], [AC, 10]]|
|2018|[[TV, 2]] |
|2018|[[TV, 2], [AC, 3]] |
What is the best way to remove the duplicate rows?
Upvotes: 2
Views: 1871
Reputation: 18013
Initially I was looking for an approach without a UDF. That was OK except for one aspect that I could not solve elegantly. With a simple map UDF it is extremely simple, simpler than the other answers. So here it is for posterity, a little later than intended due to other commitments.
Try this...
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import spark.implicits._
case class abc(year: Int, month: Int, item: String, quantity: Int)
// month goes first in the struct so that sort_array orders each list by month
val itemsList = collect_list(struct("month", "item", "quantity"))
// Drop month (field 0) from each sorted struct, keeping item (1) and quantity (2)
val my_udf = udf { items: Seq[Row] =>
  val res = items.map { r => (r.getAs[String](1), r.getAs[Int](2)) }
  res
}
// Gen some data, however, not the thrust of the problem.
val df0 = Seq(abc(2019, 1, "TV", 8), abc(2019, 7, "AC", 10), abc(2018, 1, "TV", 2), abc(2018, 2, "AC", 3), abc(2019, 2, "CO", 7)).toDS()
val df1 = df0.toDF()
val df2 = df1.groupBy($"year")
  .agg(itemsList as "items")
  .withColumn("sortedCol", sort_array($"items", asc = true))
  .withColumn("sortedItems", my_udf(col("sortedCol")))
  .drop("items", "sortedCol") // keep only year and sortedItems
Noting the following that you should fix:
|year|sortedItems |
|2019|[[TV, 8], [CO, 7], [AC, 10]]|
|2018|[[TV, 2], [AC, 3]] |
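For what it's worth, on Spark 2.4 or later the UDF step can probably be replaced with the SQL higher-order function transform. The following is only a minimal sketch under that assumption, using a local SparkSession and the same sample rows; the names ItemRow, sortedCol and udfFree are made up for the illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Hypothetical row type for this sketch, mirroring the sample data above
case class ItemRow(year: Int, month: Int, item: String, quantity: Int)

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

val sample = Seq(
  ItemRow(2019, 1, "TV", 8), ItemRow(2019, 7, "AC", 10),
  ItemRow(2018, 1, "TV", 2), ItemRow(2018, 2, "AC", 3),
  ItemRow(2019, 2, "CO", 7)
).toDF()

val udfFree = sample.groupBy($"year")
  // month is the first struct field, so sort_array orders each list by month
  .agg(sort_array(collect_list(struct("month", "item", "quantity"))).as("sortedCol"))
  // transform (a Spark SQL function, 2.4+) rebuilds each element without month
  .withColumn("sortedItems",
    expr("transform(sortedCol, x -> named_struct('item', x.item, 'quantity', x.quantity))"))
  .drop("sortedCol")
  .orderBy($"year".desc)

udfFree.show(false)
```

The lists should come out in the same month order as in the output above. The trade-off is that the lambda lives in a SQL string, so typos in field names only show up at analysis time.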
Upvotes: 0
Reputation: 36
I believe the interesting part here is that the aggregated list of items has to be sorted by month. So I've written the code using three approaches:
Creating a sample dataset:
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
case class data(year : Int, month : Int, item : String, quantity : Int)
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
val inputDF = spark.createDataset(Seq(
  data(2018, 2, "AC", 3),
  data(2019, 2, "AC", 10),
  data(2019, 1, "TV", 2),
  data(2018, 1, "TV", 2)
)).toDF()
Approach1: Aggregating month, item and quantity into a list and then sorting the items by month using a UDF:
case class items(item : String, quantity : Int)
def getItemsSortedByMonth(itemsRows : Seq[Row]) : Seq[items] = {
  if (itemsRows == null || itemsRows.isEmpty) {
    // nothing collected for this year
    Seq.empty[items]
  }
  else {
    // sort the collected structs by month, then keep only item and quantity
    itemsRows.sortBy(r => r.getAs[Int]("month"))
      .map(r => items(r.getAs[String]("item"), r.getAs[Int]("quantity")))
  }
}
val itemsSortedByMonthUDF = udf(getItemsSortedByMonth(_: Seq[Row]))
val outputDF = inputDF.groupBy(col("year"))
  .agg(collect_list(struct("month", "item", "quantity")).as("items"))
  .withColumn("items", itemsSortedByMonthUDF(col("items")))
Approach2: Using window functions
val monthWindowSpec = Window.partitionBy("year").orderBy("month")
val rowNumberWindowSpec = Window.partitionBy("year").orderBy("row_number")
val runningList = collect_list(struct("item", "quantity")).over(rowNumberWindowSpec)
val tempDF = inputDF
  // using row_number for continuous ranks if there are multiple items in the same month
  .withColumn("row_number", row_number().over(monthWindowSpec))
  .withColumn("items", runningList)
  .drop("month", "item", "quantity")
// the row with the highest row_number per year holds the complete list
val yearToSelect = tempDF.groupBy("year").agg(max("row_number").as("row_number"))
val outputDF = tempDF.join(yearToSelect, Seq("year", "row_number")).drop("row_number")
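A possible variant of this approach that avoids the join, offered as a sketch rather than a tested recommendation: when a window spec has an orderBy, its default frame runs from the start of the partition to the current row, which is why each row carries a running list. Widening the frame to the whole partition puts the full list on every row, so distinct can collapse them. In practice the list follows the window ordering. The names SaleRow, sales, fullPartition and fullLists are assumptions for the illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Hypothetical row type for this sketch, same shape as the sample dataset
case class SaleRow(year: Int, month: Int, item: String, quantity: Int)

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

val sales = Seq(
  SaleRow(2018, 2, "AC", 3), SaleRow(2019, 2, "AC", 10),
  SaleRow(2019, 1, "TV", 2), SaleRow(2018, 1, "TV", 2)
).toDF()

// Frame covers the entire partition instead of "start up to the current row"
val fullPartition = Window.partitionBy("year").orderBy("month")
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

val fullLists = sales
  .select(col("year"),
    collect_list(struct("item", "quantity")).over(fullPartition).as("items"))
  // every row of a year now holds the same complete list
  .distinct()

fullLists.show(false)
```

The distinct still costs a shuffle, so for large inputs the plain groupBy in Approach1 may well be cheaper.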
Edit: Added a third approach for posterity, using the Dataset APIs groupByKey and mapGroups:
// encoding to the data class can be avoided if inputDF is not converted to a Dataset of Row objects
val outputDF = inputDF.as[data].groupByKey(_.year).mapGroups{ case (year, rows) =>
  val itemsSortedByMonth = rows.toSeq.sortBy(_.month).map(s => items(s.item, s.quantity))
  (year, itemsSortedByMonth)
}.toDF("year", "items")
Upvotes: 2