Buck

Reputation: 501

How to implement generic function in Scala with two argument types?

I'd like to implement a function in Scala that computes the dot product of two numeric sequences as follows:

val x = Seq(1,2,3.0)
val y = Seq(4,5,6)
val z = (for (a <- x; b <- y) yield a*b).sum
//> z: Double = 90.0

val x = Seq(1,2,3)
val y = Seq(4,5,6)
val z = (for (a <- x; b <- y) yield a*b).sum
//> z: Int = 90

Notice that if the two sequences are of different types (Double and Int in the first example), the result is a Double. If the two sequences are of the same type (e.g. Int), the result is of that type (an Int in the second example).

I came up with two alternatives, but neither meets the requirements defined above.

Alternative #1:

def dotProduct[T: Numeric](x: Seq[T], y: Seq[T]): T = (for (a <- x; b <- y) yield implicitly[Numeric[T]].times(a, b)).sum

This returns a result of the same type as the inputs, but it cannot accept two sequences with different element types.

Alternative #2:

def dotProduct[A, B](x: Seq[A], y: Seq[B])(implicit nx: Numeric[A], ny: Numeric[B]) = (for (a <- x; b <- y) yield nx.toDouble(a)*ny.toDouble(b)).sum

This works for all numeric sequences. However, it always returns a Double, even when both sequences are of type Int.

Any suggestion is greatly appreciated.

P.S. The function I implemented above is not a dot product but simply the sum of the products of all pairs of elements drawn from the two sequences. Thanks to Daniel for pointing this out.
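Written out for the example inputs, the difference between the two computations is easy to see. Summing over all pairs factors into the product of the two sums, which is why every version above prints 90, while the dot product (defined only for sequences of equal length n) pairs elements by position:

```latex
\sum_{i=1}^{n}\sum_{j=1}^{m} x_i\,y_j
  = \Big(\sum_{i=1}^{n} x_i\Big)\Big(\sum_{j=1}^{m} y_j\Big)
  = (1+2+3)(4+5+6) = 6 \cdot 15 = 90,
\qquad
x \cdot y = \sum_{i=1}^{n} x_i\,y_i = 1\cdot 4 + 2\cdot 5 + 3\cdot 6 = 32.
```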

Alternative #3 (slightly better than alternatives #1 and #2):

def sumProduct[T, A <% T, B <% T](x: Seq[A], y: Seq[B])(implicit num: Numeric[T]) = (for (a <- x; b <- y) yield num.times(a,b)).sum

sumProduct(Seq(1,2,3), Seq(4,5,6))  //> res0: Int = 90
sumProduct(Seq(1,2,3.0), Seq(4,5,6))  //> res1: Double = 90.0
sumProduct(Seq(1,2,3), Seq(4,5,6.0))  // Fails!!!

Unfortunately, the view bound feature (i.e. the "<%" syntax) will be deprecated in Scala 2.10.

Upvotes: 3

Views: 1406

Answers (1)

Daniel Yankowsky

Reputation: 7016

You could create a typeclass that represents the promotion rules:

trait NumericPromotion[A, B, C] {
  def promote(a: A, b: B): (C, C)
}

implicit object IntDoublePromotion extends NumericPromotion[Int, Double, Double] {
  def promote(a: Int, b: Double): (Double, Double) = (a.toDouble, b)
}

def dotProduct[A, B, C]
              (x: Seq[A], y: Seq[B])
              (implicit numEv: Numeric[C], promEv: NumericPromotion[A, B, C])
              : C = {
  val foo = for {
    a <- x
    b <- y
  } yield {
    val (pa, pb) = promEv.promote(a, b)
    numEv.times(pa, pb)
  }

  foo.sum
}

dotProduct[Int, Double, Double](Seq(1, 2, 3), Seq(1.0, 2.0, 3.0))

My typeclass-fu isn't good enough to eliminate the explicit type parameters in the call to dotProduct, nor could I figure out how to avoid the val foo inside the method; inlining foo led to compiler errors. I chalk this up to not having really internalized the implicit resolution rules. Maybe somebody else can get you further.
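One possible way around the explicit type parameters, offered as a sketch rather than a verified recipe: as far as I understand it, Scala resolves an implicit parameter list from left to right, and a type parameter that only appears in implicits can be fixed by the first implicit that mentions it. If the promotion evidence comes before `Numeric[C]`, the compiler can pick `C` from the matching `NumericPromotion` instance and only then look up `Numeric[C]`. The names `PromotionSketch`, `sumProduct`, `same`, `intDouble` and `doubleInt` are hypothetical. The instances are placed in the `NumericPromotion` companion so they are found without imports, and both Int/Double directions are declared explicitly instead of going through `flipNumericPromotion`.

```scala
object PromotionSketch {
  // Same typeclass as above: maps a pair of element types (A, B) to a common type C.
  trait NumericPromotion[A, B, C] {
    def promote(a: A, b: B): (C, C)
  }

  // Instances in the companion object are part of NumericPromotion's
  // implicit scope, so callers do not need to import them.
  object NumericPromotion {
    // Same element type on both sides: no promotion, result type is A.
    implicit def same[A]: NumericPromotion[A, A, A] =
      new NumericPromotion[A, A, A] {
        def promote(a: A, b: A): (A, A) = (a, b)
      }

    // Mixed Int/Double, in either order: promote both sides to Double.
    implicit val intDouble: NumericPromotion[Int, Double, Double] =
      new NumericPromotion[Int, Double, Double] {
        def promote(a: Int, b: Double): (Double, Double) = (a.toDouble, b)
      }

    implicit val doubleInt: NumericPromotion[Double, Int, Double] =
      new NumericPromotion[Double, Int, Double] {
        def promote(a: Double, b: Int): (Double, Double) = (a, b.toDouble)
      }
  }

  // Sum of the products of all pairs (the question's semantics).
  // promEv is listed before numEv so that C is determined by the
  // promotion instance before Numeric[C] is looked up.
  def sumProduct[A, B, C](x: Seq[A], y: Seq[B])
                         (implicit promEv: NumericPromotion[A, B, C],
                                   numEv: Numeric[C]): C = {
    val products: Seq[C] = for {
      a <- x
      b <- y
    } yield {
      val (pa, pb) = promEv.promote(a, b)
      numEv.times(pa, pb)
    }
    products.sum
  }
}
```

With these definitions, `PromotionSketch.sumProduct(Seq(1, 2, 3), Seq(4, 5, 6))` should be typed as Int, and the mixed Int/Double calls as Double in either argument order. Supporting Long or Float would need further instances.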

It's also worth mentioning that this is directional; you couldn't compute dotProduct(Seq(1.0, 2.0, 3.0), Seq(1, 2, 3)). But that's easy to fix:

implicit def flipNumericPromotion[A, B, C]
                                 (implicit promEv: NumericPromotion[B, A, C])
                                 : NumericPromotion[A, B, C] = 
  new NumericPromotion[A, B, C] {
    override def promote(a: A, b: B): (C, C) = promEv.promote(b, a)
  }

Finally, note that your code doesn't compute a dot product: it multiplies every element of one sequence by every element of the other. The dot product of [1, 2, 3] and [4, 5, 6] is 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32.
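For comparison, here is a minimal sketch of an actual dot product. It pairs elements by position with `zip` instead of combining every element with every other. It uses a single element type `T` for simplicity; the promotion typeclass above could be layered on in the same way. The object name `DotProductSketch` is hypothetical. The equal-length `require` is my own choice, since `zip` alone would silently drop the extra elements of the longer sequence.

```scala
object DotProductSketch {
  // Positional dot product: x(0)*y(0) + x(1)*y(1) + ...
  def dotProduct[T](x: Seq[T], y: Seq[T])(implicit num: Numeric[T]): T = {
    // zip truncates to the shorter sequence, so reject unequal lengths up front.
    require(x.length == y.length, "sequences must have the same length")
    x.zip(y).map { case (a, b) => num.times(a, b) }.sum
  }
}
```

For example, `DotProductSketch.dotProduct(Seq(1, 2, 3), Seq(4, 5, 6))` gives 32.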

Upvotes: 1
