Reputation: 21
How can I get the intersection of values in key-value pairs?
I have pairs:
(p, Set(n))
in which I used reduceByKey
and finally got:
(p1, Set(n1, n2)) (p2, Set(n1, n2, n3)) (p3, Set(n2, n3))
What I want is to find the n that exist in all of the pairs and put them as the value. For the above data, the result would be
(p1, Set(n2)) (p2, Set(n2)) (p3, Set(n2))
As far as I have searched, there is no reduceByValue
in Spark. The only function that seemed close to what I want was reduce()
, but it didn't work, as the result was only one key-value pair ((p3, Set(n2)))
.
Is there any way to solve this? Or should I rethink my approach from the start?
Code:
val rRdd = inputFile.map(x => (x._1, Set(x._2))).reduceByKey(_ ++ _)
val wrongRdd = rRdd.reduce{(x, y) => (x._1, x._2.intersect(y._2))}
I can see why wrongRdd
is not correct; I just included it to show how (p3, Set(n2))
came about.
Upvotes: 0
Views: 139
Reputation: 22439
You can first reduce
the sets to their intersection (say, s
), then replace (k, v)
with (k, s)
:
val rdd = sc.parallelize(Seq(
("p1", Set("n1", "n2")),
("p2", Set("n1", "n2", "n3")),
("p3", Set("n2", "n3"))
))
val s = rdd.map(_._2).reduce(_ intersect _)
// s: scala.collection.immutable.Set[String] = Set(n2)
rdd.map{ case (k, v) => (k, s) }.collect
// res1: Array[(String, scala.collection.immutable.Set[String])] = Array(
// (p1,Set(n2)), (p2,Set(n2)), (p3,Set(n2))
// )
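The same reduce-then-map idea can be checked on plain Scala collections, since `reduce` and `intersect` behave the same way there as on an RDD. A minimal sketch (no SparkContext needed; the sample data mirrors the question's):

```scala
// Plain-Scala sketch of the same approach:
// 1) reduce all value sets to their common intersection,
// 2) pair that intersection with every original key.
val pairs = Seq(
  ("p1", Set("n1", "n2")),
  ("p2", Set("n1", "n2", "n3")),
  ("p3", Set("n2", "n3"))
)

// Intersection of all value sets.
val s = pairs.map(_._2).reduce(_ intersect _)

// Replace each value with the shared intersection.
val result = pairs.map { case (k, _) => (k, s) }
// result: Seq((p1,Set(n2)), (p2,Set(n2)), (p3,Set(n2)))
```

Note that on an RDD, `s` is computed on the driver by `reduce` and then captured in the closure of the second `map`, so this works without any extra shuffling.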
Upvotes: 2