Ashkan

Spark Scala: Pass a subtype to a function accepting the parent type

Suppose I have an abstract class A. I also have classes B and C which inherit from class A.

abstract class A {
  def x: Int
}
case class B(i: Int) extends A {
  override def x = -i
}
case class C(i: Int) extends A {
  override def x = i
}

Given these classes, I construct the following RDD:

val data = sc.parallelize(Seq(
    Set(B(1), B(2)),
    Set(B(1), B(3)),
    Set(B(1), B(5))
  )).cache
  .zipWithIndex
  .map { case (k, v) => (v, k) }
// data: RDD[(Long, Set[B])], because the element type is inferred as Set[B]

I also have the following function, which takes such an RDD as input and returns the count of each element:

import org.apache.spark.rdd.RDD

def f(data: RDD[(Long, Set[A])]) = {
  data.flatMap({
    case (k, v) => v map { af =>
      (af, 1)
    }
  }).reduceByKey(_ + _)
}
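For reference, here is a minimal local sketch of what f is meant to compute, using plain Scala collections instead of an RDD. The object name LocalCount and the helper countElements are hypothetical; the data mirrors the question's three sets, typed with the parent A so that it compiles. B(1) appears in all three sets, and B(2), B(3) and B(5) once each.

```scala
abstract class A { def x: Int }
case class B(i: Int) extends A { override def x = -i }
case class C(i: Int) extends A { override def x = i }

object LocalCount {
  // Hypothetical local stand-in for the RDD: (index, set) pairs typed with the parent A
  val data: Seq[(Long, Set[A])] = Seq(
    (0L, Set[A](B(1), B(2))),
    (1L, Set[A](B(1), B(3))),
    (2L, Set[A](B(1), B(5)))
  )

  // Same shape as f: emit (element, 1) for each set member, then sum per element
  def countElements(data: Seq[(Long, Set[A])]): Map[A, Int] =
    data
      .flatMap { case (_, v) => v.toSeq.map(af => (af, 1)) }
      .groupBy(_._1)
      .map { case (k, pairs) => (k, pairs.map(_._2).sum) }
}
```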

Note that f is declared to take an RDD[(Long, Set[A])]. Since B is a subtype of A, I expected val x = f(data) to return the counts, but instead I get the following compile error:

type mismatch;
 found   : org.apache.spark.rdd.RDD[(Long, scala.collection.immutable.Set[B])]
 required: org.apache.spark.rdd.RDD[(Long, Set[A])]
    val x = f(data)

This error goes away if I change the function signature to f(data: RDD[(Long, Set[B])]); however, I can't do that, because I also want to use other subclasses of A (such as C) in the RDD.

I have also tried the following approach:

def f[T <: A](data: RDD[(Long, Set[T])]) = {
  data.flatMap({
    case (k, v) => v map { af =>
      (af, 1)
    }
  }) reduceByKey(_ + _)
}

However, this gives me another compile-time error:

value reduceByKey is not a member of org.apache.spark.rdd.RDD[(T, Int)]
possible cause: maybe a semicolon is missing before `value reduceByKey'?
      }) reduceByKey(_ + _)

I appreciate any help on this.

Answers (1)

maasg

Set[T] is invariant in T: given that B is a subtype of A, Set[B] is neither a subtype nor a supertype of Set[A]. RDD[T] is also invariant in T, which restricts the options further: even if a covariant collection such as List[+T] were used, the same mismatch would arise, because an RDD[(Long, List[B])] is still not an RDD[(Long, List[A])].
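A minimal plain-Scala sketch of the variance point, with no Spark involved. It reuses the question's A/B/C hierarchy; the object and method names (VarianceDemo, sumList, sumSet) are hypothetical. A List[B] is accepted where a List[A] is expected, while a Set[B] has to be widened explicitly or built as a Set[A] from the start.

```scala
abstract class A { def x: Int }
case class B(i: Int) extends A { override def x = -i }
case class C(i: Int) extends A { override def x = i }

object VarianceDemo {
  def sumList(xs: List[A]): Int = xs.map(_.x).sum
  // toList first so that equal x values are not collapsed by Set.map
  def sumSet(xs: Set[A]): Int = xs.toList.map(_.x).sum

  val bList: List[B] = List(B(1), B(2))
  // Compiles: List is covariant, so List[B] <: List[A]
  val fromList: Int = sumList(bList)

  val bSet: Set[B] = Set(B(1), B(2))
  // sumSet(bSet) would NOT compile: Set is invariant, so a Set[B] is not a Set[A]
  // Widening explicitly builds a new Set[A] from the same elements
  val fromSet: Int = sumSet(bSet.toSet[A])

  // Alternatively, fix the element type where the collection is built
  val mixed: Set[A] = Set[A](B(1), C(3))
  val fromMixed: Int = sumSet(mixed)
}
```

Applied to the question, the same idea would be to build the sets with an explicit element type, e.g. sc.parallelize(Seq[Set[A]](...)), which should yield an RDD[(Long, Set[A])] that the original non-generic f accepts.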

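As an alternative, we can make the method polymorphic. What's missing from the generic version in the question is a ClassTag for T, which Spark requires to preserve class information after erasure. reduceByKey is not a member of RDD itself; it comes from an implicit conversion to PairRDDFunctions, and that conversion needs an implicit ClassTag for the key type. With no ClassTag[T] in scope, the conversion cannot apply, so the compiler reports that reduceByKey is not a member of RDD[(T, Int)].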

This should work:

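import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// The ClassTag context bound lets the implicit conversion to PairRDDFunctions
// (which provides reduceByKey) find a ClassTag for the key type T
def f[T: ClassTag](data: RDD[(Long, Set[T])]) = {
  data.flatMap {
    case (_, v) => v.map(af => (af, 1))
  }.reduceByKey(_ + _)
}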

Let's see:

val intRdd = sparkContext.parallelize(Seq((1L, Set(1, 2, 3)), (2L, Set(4, 5, 6))))
val res1 = f(intRdd).collect
// Array[(Int, Int)] = Array((4,1), (1,1), (5,1), (6,1), (2,1), (3,1))

val strRdd = sparkContext.parallelize(Seq((1L, Set("a", "b", "c")), (2L, Set("d", "e", "f"))))
val res2 = f(strRdd).collect
// Array[(String, Int)] = Array((d,1), (e,1), (a,1), (b,1), (f,1), (c,1))
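To see why the context bound is what makes reduceByKey resolve, here is a Spark-free mimic of the mechanism. MiniRDD and MiniPairFunctions are hypothetical stand-ins, not Spark classes; the only thing they copy from Spark is the shape of the problem: reduceByKey lives on a separate class that is reached through an implicit conversion demanding a ClassTag for the key. Removing ": ClassTag" from ClassTagDemo.f should reproduce a "value reduceByKey is not a member" compile error.

```scala
import scala.language.implicitConversions
import scala.reflect.ClassTag

abstract class A { def x: Int }
case class B(i: Int) extends A { override def x = -i }
case class C(i: Int) extends A { override def x = i }

// Hypothetical stand-in for RDD: just wraps a local Seq
final case class MiniRDD[T](items: Seq[T])

// Hypothetical stand-in for PairRDDFunctions; the ClassTag is unused here and
// only reproduces the requirement that Spark's pair-RDD conversion imposes
final class MiniPairFunctions[K, V](rdd: MiniRDD[(K, V)])(implicit kt: ClassTag[K]) {
  def reduceByKey(op: (V, V) => V): MiniRDD[(K, V)] =
    MiniRDD(rdd.items.groupBy(_._1).toSeq.map { case (k, kvs) => (k, kvs.map(_._2).reduce(op)) })
}

object MiniRDD {
  // Implicit view: only applicable when a ClassTag[K] can be found
  implicit def toPairFunctions[K: ClassTag, V](rdd: MiniRDD[(K, V)]): MiniPairFunctions[K, V] =
    new MiniPairFunctions(rdd)
}

object ClassTagDemo {
  // Without ": ClassTag" the view above cannot apply, so reduceByKey is "not a member"
  def f[T: ClassTag](data: MiniRDD[(Long, Set[T])]): MiniRDD[(T, Int)] =
    MiniRDD(data.items.flatMap { case (_, v) => v.toSeq.map(af => (af, 1)) }).reduceByKey(_ + _)

  // The question's upper bound can be kept alongside the context bound
  def fA[T <: A : ClassTag](data: MiniRDD[(Long, Set[T])]): MiniRDD[(T, Int)] = f(data)
}
```

The same combined bound, def f[T <: A : ClassTag](data: RDD[(Long, Set[T])]), should work with the Spark version if the restriction to subclasses of A is still wanted.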
