A. Frank

Reputation: 83

Combine two datasets based on value

I have the following two datasets:

val dfA = Seq(
("001", "10", "Cat"),
("001", "20", "Dog"),
("001", "30", "Bear"),
("002", "10", "Mouse"),
("002", "20", "Squirrel"),
("002", "30", "Turtle")
).toDF("Package", "LineItem", "Animal")

val dfB = Seq(
("001", "", "X", "A"),
("001", "", "Y", "B"),
("002", "", "X", "C"),
("002", "", "Y", "D"),
("002", "20", "X" ,"E")
).toDF("Package", "LineItem", "Flag", "Category")

I need to join them with specific conditions:

a) dfB always contains a row with the X flag and an empty LineItem. Its Category is the default Category for every row of that Package in dfA.

b) When dfB specifies a LineItem, the Category on that row overrides the default for the matching LineItem in dfA.

Expected output:

+---------+----------+----------+----------+
| Package | LineItem | Animal   | Category |
+---------+----------+----------+----------+
| 001     | 10       | Cat      | A        |
+---------+----------+----------+----------+
| 001     | 20       | Dog      | A        |
+---------+----------+----------+----------+
| 001     | 30       | Bear     | A        |
+---------+----------+----------+----------+
| 002     | 10       | Mouse    | C        |
+---------+----------+----------+----------+
| 002     | 20       | Squirrel | E        |
+---------+----------+----------+----------+
| 002     | 30       | Turtle   | C        |
+---------+----------+----------+----------+

I spent some time on this today, but I have no idea how to accomplish it. I would appreciate your assistance. Thanks!

Upvotes: 1

Views: 50

Answers (2)

Fqp

Reputation: 143

This should work for you:

import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._

val dfA = Seq(
("001", "10", "Cat"),
("001", "20", "Dog"),
("001", "30", "Bear"),
("002", "10", "Mouse"),
("002", "20", "Squirrel"),
("002", "30", "Turtle")
).toDF("Package", "LineItem", "Animal")

val dfB = Seq(
("001", "", "X", "A"),
("001", "", "Y", "B"),
("002", "", "X", "C"),
("002", "", "Y", "D"),
("002", "20", "X" ,"E")
).toDF("Package", "LineItem", "Flag", "Category")

val result = {
    dfA.as("a")
    // Keep only X-flagged rows of dfB; match on Package and on either the same LineItem or the default (empty) one.
    .join(dfB.where($"Flag" === "X").as("b"),
        $"a.Package" === $"b.Package" and ($"a.LineItem" === $"b.LineItem" or $"b.LineItem" === ""),
        "left")
    // True when the (Package, LineItem) group contains a row of dfB with a LineItem-specific Category.
    .withColumn("anyRowsInGroupWithBLineItemDefined",
        first(when($"b.LineItem" =!= "", lit(true)), ignoreNulls = true).over(Window.partitionBy($"a.Package", $"a.LineItem")).isNotNull)
    // Drop the default row whenever a LineItem-specific row exists in the group.
    .where(!$"anyRowsInGroupWithBLineItemDefined" or $"b.LineItem" =!= "")
    .select($"a.Package", $"a.LineItem", $"a.Animal", $"b.Category")
}

result.orderBy($"a.Package", $"a.LineItem").show(false)

// +-------+--------+--------+--------+
// |Package|LineItem|Animal  |Category|
// +-------+--------+--------+--------+
// |001    |10      |Cat     |A       |
// |001    |20      |Dog     |A       |
// |001    |30      |Bear    |A       |
// |002    |10      |Mouse   |C       |
// |002    |20      |Squirrel|E       |
// |002    |30      |Turtle  |C       |
// +-------+--------+--------+--------+

The "tricky" part is working out whether dfB has any row with a LineItem defined for a given (Package, LineItem) pair in dfA. The column anyRowsInGroupWithBLineItemDefined performs this calculation with a window function partitioned by that pair. Other than that, it's just a normal SQL programming exercise.

Note also that this code should be more efficient than the other solution: it shuffles the data only twice (once for the join and once for the window function) and reads each dataset only once.
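
To make the rule behind anyRowsInGroupWithBLineItemDefined easier to follow, here is a minimal sketch that mirrors the join, window, and filter steps on plain Scala collections, with no Spark involved. The names CategoryRule, RowA, RowB, and resolved are made up for illustration and do not appear in the code above; the data is the sample from the question.

```scala
// Plain-Scala mirror of the join + window + filter rule (illustrative only, no Spark).
object CategoryRule {
  final case class RowA(pkg: String, lineItem: String, animal: String)
  final case class RowB(pkg: String, lineItem: String, flag: String, category: String)

  val rowsA: Seq[RowA] = Seq(
    RowA("001", "10", "Cat"), RowA("001", "20", "Dog"), RowA("001", "30", "Bear"),
    RowA("002", "10", "Mouse"), RowA("002", "20", "Squirrel"), RowA("002", "30", "Turtle"))

  val rowsB: Seq[RowB] = Seq(
    RowB("001", "", "X", "A"), RowB("001", "", "Y", "B"),
    RowB("002", "", "X", "C"), RowB("002", "", "Y", "D"),
    RowB("002", "20", "X", "E"))

  // One output tuple per surviving (a, b) pair: (Package, LineItem, Animal, Category).
  val resolved: Seq[(String, String, String, Option[String])] = rowsA.flatMap { a =>
    // Join step: X-flagged rows of the same Package that match the LineItem or are defaults.
    val candidates = rowsB.filter { b =>
      b.flag == "X" && b.pkg == a.pkg && (b.lineItem == a.lineItem || b.lineItem.isEmpty)
    }
    // Window step: does this (Package, LineItem) group contain a LineItem-specific row?
    val anySpecific = candidates.exists(_.lineItem.nonEmpty)
    // Filter step: if so, keep only the specific rows; otherwise keep the defaults.
    val kept = if (anySpecific) candidates.filter(_.lineItem.nonEmpty) else candidates
    // Left-join behaviour: a row of A with no match survives with an empty Category.
    if (kept.isEmpty) Seq((a.pkg, a.lineItem, a.animal, Option.empty[String]))
    else kept.map(b => (a.pkg, a.lineItem, a.animal, Option(b.category)))
  }
}
```

Calling CategoryRule.resolved.foreach(println) lists one tuple per dfA row for this sample, with Squirrel picking up the LineItem-specific category E and every other row falling back to its Package default.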

Upvotes: 0

User9123

Reputation: 1733

You can use two joins and a when clause:

import org.apache.spark.sql.functions.when

// LineItem-specific categories: X-flagged rows of dfB that match both Package and LineItem.
val dfC = dfA
  .join(dfB, dfB.col("Flag") === "X" && dfA.col("LineItem") === dfB.col("LineItem") && dfA.col("Package") === dfB.col("Package"))
  .select(dfA.col("Package").as("priorPackage"), dfA.col("LineItem").as("priorLineItem"), dfB.col("Category").as("priorCategory"))
  .as("dfC")

// Left-join the Package default (empty LineItem, Flag X), then the LineItem-specific category.
val dfD = dfA
  .join(dfB, dfB.col("LineItem") === "" && dfB.col("Flag") === "X" && dfA.col("Package") === dfB.col("Package"), "left_outer")
  .join(dfC, dfA.col("LineItem") === dfC.col("priorLineItem") && dfA.col("Package") === dfC.col("priorPackage"), "left_outer")
  .select(
    dfA.col("Package"),
    dfA.col("LineItem"),
    dfA.col("Animal"),
    // Prefer the LineItem-specific category; fall back to the Package default.
    when(dfC.col("priorCategory").isNotNull, dfC.col("priorCategory")).otherwise(dfB.col("Category")).as("Category")
  )

dfD.show()
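
A variation on the same two-join idea, sketched under a few assumptions: the defaults and the overrides are first split out of dfB, both joins are by column name (which avoids qualifying each column with its source DataFrame), and coalesce takes the place of when/otherwise. CategoryJoin, resolve, DefaultCategory, and OverrideCategory are hypothetical names introduced for this sketch. It also assumes, as the code above does, that only rows with Flag X matter.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{coalesce, col}

// Hypothetical helper for illustration; not part of the answer above.
object CategoryJoin {
  def resolve(dfA: DataFrame, dfB: DataFrame): DataFrame = {
    // Assumption: only X-flagged rows of dfB take part in the lookup.
    val flagged = dfB.where(col("Flag") === "X")

    // Default Category per Package: rows with an empty LineItem.
    val defaults = flagged
      .where(col("LineItem") === "")
      .select(col("Package"), col("Category").as("DefaultCategory"))

    // LineItem-specific categories: rows with a non-empty LineItem.
    val overrides = flagged
      .where(col("LineItem") =!= "")
      .select(col("Package"), col("LineItem"), col("Category").as("OverrideCategory"))

    dfA
      .join(defaults, Seq("Package"), "left")
      .join(overrides, Seq("Package", "LineItem"), "left")
      .select(
        col("Package"),
        col("LineItem"),
        col("Animal"),
        // coalesce returns the first non-null argument, so the override wins when present.
        coalesce(col("OverrideCategory"), col("DefaultCategory")).as("Category"))
  }
}
```

In spark-shell, CategoryJoin.resolve(dfA, dfB).show() should print the same six rows as dfD for the sample data.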

Upvotes: 1
