Reputation: 1673
Suppose I have an abstract class A. I also have classes B and C, which inherit from A.
abstract class A {
  def x: Int
}

case class B(i: Int) extends A {
  override def x = -i
}

case class C(i: Int) extends A {
  override def x = i
}
Given these classes, I construct the following RDD:
val data = sc.parallelize(Seq(
    Set(B(1), B(2)),
    Set(B(1), B(3)),
    Set(B(1), B(5))
  )).cache                        // element type is inferred as Set[B], not Set[A]
  .zipWithIndex                   // RDD[(Set[B], Long)]
  .map { case (k, v) => (v, k) }  // RDD[(Long, Set[B])]
I also have the following function that gets an RDD as the input and returns the count of each element:
def f(data: RDD[(Long, Set[A])]) = {
  data.flatMap {
    case (k, v) => v.map { af =>
      (af, 1)                // pair each set element with a count of 1
    }
  }.reduceByKey(_ + _)       // sum the counts per distinct element
}
Note that the function accepts an RDD of Set[A]. Now I expect val x = f(data) to return the counts as expected, since B is a subtype of A, but instead I get the following compile error:
type mismatch;
found : org.apache.spark.rdd.RDD[(Long, scala.collection.immutable.Set[B])]
required: org.apache.spark.rdd.RDD[(Long, Set[A])]
val x = f(data)
This error goes away if I change the function signature to f(data: RDD[(Long, Set[B])]); however, I can't do that, since I want to use other subclasses (like C) in the RDD.
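One workaround I can use (a sketch, which only applies when I control where the RDD is built) is to ascribe Set[A] at construction, so the existing signature matches exactly:

val dataA = sc.parallelize(Seq[Set[A]](
    Set(B(1), B(2)),
    Set(B(1), C(3)),  // mixing subclasses also works now
    Set(B(1), B(5))
  )).cache
  .zipWithIndex
  .map { case (k, v) => (v, k) }

val counts = f(dataA)  // compiles: dataA is RDD[(Long, Set[A])]

But that doesn't help when the RDD's element type is already fixed elsewhere, so I'd still like f itself to be generic.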
I have also tried the following approach:
def f[T <: A](data: RDD[(Long, Set[T])]) = {
  data.flatMap({
    case (k, v) => v map { af =>
      (af, 1)
    }
  }) reduceByKey(_ + _)
}
However, this attempt fails with the following error at compile time:
value reduceByKey is not a member of org.apache.spark.rdd.RDD[(T, Int)]
possible cause: maybe a semicolon is missing before `value reduceByKey'?
}) reduceByKey(_ + _)
I appreciate any help on this.
Upvotes: 2
Views: 523
Reputation: 37435
Set[T] is invariant in T, meaning that given B a subtype of A (as in the question), Set[B] is neither a subtype nor a supertype of Set[A].
RDD[T] is also invariant in T, which further restricts the options: even if a covariant collection Collection[+T] were used (e.g. a List[+T]), the same situation would arise.
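To make that concrete, a small sketch using the question's classes:

val bs: Set[B] = Set(B(1), B(2))
// val as: Set[A] = bs               // does not compile: Set is invariant in T
val as: Set[A] = Set[A](B(1), B(2))  // fine: elements are upcast at construction

val lb: List[B] = List(B(1))
val la: List[A] = lb                 // compiles: List is covariant (+T)
// ...but an RDD[(Long, List[B])] is still not an RDD[(Long, List[A])],
// because RDD itself is invariant.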
We can resort to a polymorphic form of the method as an alternative. What was missing from the generic f[T <: A] version attempted in the question is a ClassTag, which Spark requires to preserve class information after erasure. This should work:
import scala.reflect.ClassTag
def f[T: ClassTag](data: RDD[(Long, Set[T])]) = {
  data.flatMap {
    case (k, v) => v.map { af =>
      (af, 1)
    }
  }.reduceByKey(_ + _)
}
Let's see:
val intRdd = sparkContext.parallelize(Seq((1L, Set(1, 2, 3)), (2L, Set(4, 5, 6))))
val res1 = f(intRdd).collect
// Array[(Int, Int)] = Array((4,1), (1,1), (5,1), (6,1), (2,1), (3,1))
val strRdd = sparkContext.parallelize(Seq((1L, Set("a", "b", "c")), (2L, Set("d", "e", "f"))))
val res2 = f(strRdd).collect
// Array[(String, Int)] = Array((d,1), (e,1), (a,1), (b,1), (f,1), (c,1))
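Applied to the question's original data (assuming the same sc, classes, and data value are in scope), T is inferred as B and the call now compiles:

val res3 = f(data).collect
// Array[(B, Int)] = Array((B(1),3), (B(2),1), (B(3),1), (B(5),1))  (order may vary)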
Upvotes: 2