Reputation: 23068
I'm looking for a way to join the two following Spark Datasets:
# city_visits:
person_id  city       timestamp
-----------------------------------------------
1          Paris      2017-01-01 00:00:00
1          Amsterdam  2017-01-03 00:00:00
1          Brussels   2017-01-04 00:00:00
1          London     2017-01-06 00:00:00
2          Berlin     2017-01-01 00:00:00
2          Brussels   2017-01-02 00:00:00
2          Berlin     2017-01-06 00:00:00
2          Hamburg    2017-01-07 00:00:00
# ice_cream_events:
person_id  flavour     timestamp
-----------------------------------------------
1          Vanilla     2017-01-02 00:12:00
1          Chocolate   2017-01-05 00:18:00
2          Strawberry  2017-01-03 00:09:00
2          Caramel     2017-01-05 00:15:00
So that for each row in city_visits, the row in ice_cream_events with the same person_id and the next timestamp value is joined, leading to this output:
person_id  city       timestamp            ic_flavour  ic_timestamp
---------------------------------------------------------------------------
1          Paris      2017-01-01 00:00:00  Vanilla     2017-01-02 00:12:00
1          Amsterdam  2017-01-03 00:00:00  Chocolate   2017-01-05 00:18:00
1          Brussels   2017-01-04 00:00:00  Chocolate   2017-01-05 00:18:00
1          London     2017-01-06 00:00:00  null        null
2          Berlin     2017-01-01 00:00:00  Strawberry  2017-01-03 00:09:00
2          Brussels   2017-01-02 00:00:00  Strawberry  2017-01-03 00:09:00
2          Berlin     2017-01-06 00:00:00  null        null
2          Hamburg    2017-01-07 00:00:00  null        null
The closest solution I have so far is the following; however, this obviously joins every row in ice_cream_events that matches the conditions, not just the first one:
val cv = city_visits.orderBy("person_id", "timestamp")
val ic = ice_cream_events.orderBy("person_id", "timestamp")
val result = cv.join(ic, ic("person_id") === cv("person_id")
                      && ic("timestamp") > cv("timestamp"))
Is there a (preferably efficient) way to specify that the join should use only the first matching ice_cream_events row, rather than all of them?
Upvotes: 2
Views: 7042
Reputation: 341
A request: please include sc.parallelize code in questions; it makes them easier to answer.
// Assumes a spark-shell session, so spark.implicits._ is already in scope
// for .toDF and the $"..." column syntax. Timestamps are kept as strings;
// the yyyy-MM-dd HH:mm:ss format compares correctly as plain strings.
val city_visits = sc.parallelize(Seq(
  (1, "Paris",     "2017-01-01 00:00:00"),
  (1, "Amsterdam", "2017-01-03 00:00:00"),
  (1, "Brussels",  "2017-01-04 00:00:00"),
  (1, "London",    "2017-01-06 00:00:00"),
  (2, "Berlin",    "2017-01-01 00:00:00"),
  (2, "Brussels",  "2017-01-02 00:00:00"),
  (2, "Berlin",    "2017-01-06 00:00:00"),
  (2, "Hamburg",   "2017-01-07 00:00:00")
)).toDF("person_id", "city", "timestamp")

val ice_cream_events = sc.parallelize(Seq(
  (1, "Vanilla",    "2017-01-02 00:12:00"),
  (1, "Chocolate",  "2017-01-05 00:18:00"),
  (2, "Strawberry", "2017-01-03 00:09:00"),
  (2, "Caramel",    "2017-01-05 00:15:00")
)).toDF("person_id", "flavour", "timestamp")
As suggested in the comments, you can first do the join, which creates every possible row combination that satisfies the conditions.
val joinedRes = city_visits.as("C").
  join(ice_cream_events.as("I"),
    joinType = "LEFT_OUTER",
    joinExprs =
      $"C.person_id" === $"I.person_id" &&
      $"C.timestamp" < $"I.timestamp"
  ).select($"C.person_id", $"C.city", $"C.timestamp",
           $"I.flavour".as("ic_flavour"), $"I.timestamp".as("ic_timestamp"))
joinedRes.orderBy($"person_id", $"timestamp").show
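For the sample data above, the intermediate result should look roughly like this (one row per qualifying visit/event pair, so visits can repeat; row order within a tie is not guaranteed):

person_id  city       timestamp            ic_flavour  ic_timestamp
---------------------------------------------------------------------------
1          Paris      2017-01-01 00:00:00  Vanilla     2017-01-02 00:12:00
1          Paris      2017-01-01 00:00:00  Chocolate   2017-01-05 00:18:00
1          Amsterdam  2017-01-03 00:00:00  Chocolate   2017-01-05 00:18:00
1          Brussels   2017-01-04 00:00:00  Chocolate   2017-01-05 00:18:00
1          London     2017-01-06 00:00:00  null        null
2          Berlin     2017-01-01 00:00:00  Strawberry  2017-01-03 00:09:00
2          Berlin     2017-01-01 00:00:00  Caramel     2017-01-05 00:15:00
2          Brussels   2017-01-02 00:00:00  Strawberry  2017-01-03 00:09:00
2          Brussels   2017-01-02 00:00:00  Caramel     2017-01-05 00:15:00
2          Berlin     2017-01-06 00:00:00  null        null
2          Hamburg    2017-01-07 00:00:00  null        null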
And then pick the first record per visit using a groupBy clause.
import org.apache.spark.sql.functions._

val firstMatchRes = joinedRes.
  groupBy($"person_id", $"city", $"timestamp").
  // Alias the aggregates so the column names match the desired output.
  agg(first($"ic_flavour").as("ic_flavour"), first($"ic_timestamp").as("ic_timestamp"))
firstMatchRes.orderBy($"person_id", $"timestamp").show
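One caveat: first() after a shuffle is not guaranteed to see rows in timestamp order, so which matching event it keeps is technically arbitrary. A minimal sketch of a deterministic alternative (the name firstMatchDeterministic is mine), assuming you want the earliest event per visit, uses min over a struct with the timestamp as the leading field:

// Sketch: structs compare field by field, so putting ic_timestamp first
// makes min() keep the earliest matching event deterministically.
// No-match visits keep their nulls, since their group holds a single row.
val firstMatchDeterministic = joinedRes.
  groupBy($"person_id", $"city", $"timestamp").
  agg(min(struct($"ic_timestamp", $"ic_flavour")).as("ic")).
  select($"person_id", $"city", $"timestamp",
         $"ic.ic_flavour".as("ic_flavour"), $"ic.ic_timestamp".as("ic_timestamp"))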
Now comes the trickier part, as I found out myself: the join above creates a ginormous upswell of data, and Spark has to wait for the whole join to finish before it can run the groupBy, which leads to memory issues.
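To put rough numbers on it: with V visits and E ice-cream events for a person, the intermediate join can hold up to V × E rows for that person before the groupBy shrinks it back to V. Even in the toy data, person 1's four visits expand to five joined rows; with thousands of events per person the blow-up grows multiplicatively.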
The workaround I used is a stateful join: each executor keeps local state in a Bloom filter and emits only the first row it sees for each key.
import org.apache.spark.sql.functions._

var bloomFilter = breeze.util.BloomFilter.optimallySized[String](
  city_visits.count(), falsePositiveRate = 0.0000001)

val isFirstOfItsName = udf((uniqueKey: String, joinExprs: Boolean) =>
  if (joinExprs) {
    // Only update the bloom filter if all other join expressions evaluated
    // to true. Evaluation order of the join clause is not guaranteed in a
    // DataFrame, so we have to enforce it here.
    val res = bloomFilter.contains(uniqueKey)
    bloomFilter += uniqueKey
    !res
  } else false)
val joinedRes = city_visits.as("C").
  join(ice_cream_events.as("I"),
    joinType = "LEFT_OUTER",
    joinExprs = isFirstOfItsName(
      // Unique key to identify the first occurrence of a visit row.
      concat($"C.person_id", $"C.city", $"C.timestamp"),
      // All the other join conditions go in here.
      $"C.person_id" === $"I.person_id" && $"C.timestamp" < $"I.timestamp")
  ).select($"C.person_id", $"C.city", $"C.timestamp",
           $"I.flavour".as("ic_flavour"), $"I.timestamp".as("ic_timestamp"))

joinedRes.orderBy($"person_id", $"timestamp").show
Finally, combine the partial results from the different executors.
val firstMatchRes = joinedRes.
  groupBy($"person_id", $"city", $"timestamp").
  agg(first($"ic_flavour").as("ic_flavour"), first($"ic_timestamp").as("ic_timestamp"))
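If everything lines up, displaying the combined result, ordered as before, should reproduce the output requested in the question. One thing to keep in mind: a Bloom filter can return false positives, so a genuine first match can occasionally be dropped; that is why the filter above is sized for a very low false-positive rate.

firstMatchRes.orderBy($"person_id", $"timestamp").show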
Upvotes: 2