Reputation: 501
I'd like to implement a function in Scala that computes the dot product of two numeric sequences, as follows:
val x = Seq(1,2,3.0)
val y = Seq(4,5,6)
val z = (for (a <- x; b <- y) yield a*b).sum
// z: Double = 90.0
val x = Seq(1,2,3)
val y = Seq(4,5,6)
val z = (for (a <- x; b <- y) yield a*b).sum
// z: Int = 90
Notice that if the two sequences have different element types, the result is a Double. If both sequences have the same element type (e.g. Int), the result is of that type (here an Int).
I came up with two alternatives but neither meets the requirement as defined above.
Alternative #1:
def dotProduct[T: Numeric](x: Seq[T], y: Seq[T]): T = (for (a <- x; b <- y) yield implicitly[Numeric[T]].times(a, b)).sum
This returns the result in the same type as the input, but it can't take two different types.
Alternative #2:
def dotProduct[A, B](x: Seq[A], y: Seq[B])(implicit nx: Numeric[A], ny: Numeric[B]) = (for (a <- x; b <- y) yield nx.toDouble(a)*ny.toDouble(b)).sum
This works for all numeric sequences. However, it always returns a Double, even if both sequences are of type Int.
Any suggestion is greatly appreciated.
P.S. The function I implemented above is not a "dot product" but simply the sum of the products of all pairs of elements from the two sequences. Thanks, Daniel, for pointing that out.
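To make the distinction concrete, here is a minimal sketch, assuming a single element type T with a Numeric instance. The object and method names (ProductSums, sumOfAllProducts, dotProduct) are hypothetical. The first method is the all-pairs computation from the question; the second pairs elements by position with zip, which is what a dot product does.

```scala
// Hypothetical helper object contrasting the two computations.
// Both are generic over a single element type T with a Numeric instance.
object ProductSums {
  // Sums a * b over every pair (a, b), i.e. the for/yield from the question.
  // The result equals x.sum * y.sum.
  def sumOfAllProducts[T](x: Seq[T], y: Seq[T])(implicit num: Numeric[T]): T =
    (for (a <- x; b <- y) yield num.times(a, b)).sum

  // Element-wise dot product: multiplies x(i) with y(i) and sums the results.
  // Note: zip silently drops trailing elements if the lengths differ.
  def dotProduct[T](x: Seq[T], y: Seq[T])(implicit num: Numeric[T]): T =
    x.zip(y).map { case (a, b) => num.times(a, b) }.sum
}
```

With Seq(1, 2, 3) and Seq(4, 5, 6), sumOfAllProducts returns 90 while dotProduct returns 32.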
Alternative #3 (slightly better than alternatives #1 and #2):
def sumProduct[T, A <% T, B <% T](x: Seq[A], y: Seq[B])(implicit num: Numeric[T]) = (for (a <- x; b <- y) yield num.times(a,b)).sum
sumProduct(Seq(1,2,3), Seq(4,5,6)) //> res0: Int = 90
sumProduct(Seq(1,2,3.0), Seq(4,5,6)) //> res1: Double = 90.0
sumProduct(Seq(1,2,3), Seq(4,5,6.0)) // Fails!!!
Unfortunately, view bounds ("<%") will be deprecated in Scala 2.10.
Upvotes: 3
Views: 1406
Reputation: 7016
You could create a typeclass that represents the promotion rules:
trait NumericPromotion[A, B, C] {
def promote(a: A, b: B): (C, C)
}
implicit object IntDoublePromotion extends NumericPromotion[Int, Double, Double] {
def promote(a: Int, b: Double): (Double, Double) = (a.toDouble, b)
}
def dotProduct[A, B, C]
(x: Seq[A], y: Seq[B])
(implicit numEv: Numeric[C], promEv: NumericPromotion[A, B, C])
: C = {
val foo = for {
a <- x
b <- y
} yield {
val (pa, pb) = promEv.promote(a, b)
numEv.times(pa, pb)
}
foo.sum
}
dotProduct[Int, Double, Double](Seq(1, 2, 3), Seq(1.0, 2.0, 3.0))
My typeclass-fu isn't good enough to eliminate the explicit type parameters in the call to dotProduct, nor could I figure out how to avoid the val foo inside the method; inlining foo led to compiler errors. I chalk this up to not having really internalized the implicit resolution rules. Maybe somebody else can get you further.
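One way to remove the explicit type parameters, sketched here under the assumption that the set of supported type pairs is fixed up front, is to turn the result type C into a type member and expose it through an Aux alias, a common Scala typeclass pattern. Promote, Aux, instance and sumOfProducts are hypothetical names, not part of any library. The sketch keeps the question's all-pairs computation, and because it declares a (Double, Int) instance directly it also accepts the arguments in either order.

```scala
// Hypothetical variant of NumericPromotion: the result type is a type member
// (Out) instead of a third type parameter, so callers never have to name it.
trait Promote[A, B] {
  type Out
  def promote(a: A, b: B): (Out, Out)
  def numeric: Numeric[Out]
}

object Promote {
  // Aux alias: lets instances declare their Out type as a type parameter.
  type Aux[A, B, C] = Promote[A, B] { type Out = C }

  // Builds an instance from a conversion function and a Numeric for the result.
  private def instance[A, B, C](f: (A, B) => (C, C))(implicit n: Numeric[C]): Aux[A, B, C] =
    new Promote[A, B] {
      type Out = C
      def promote(a: A, b: B): (C, C) = f(a, b)
      def numeric: Numeric[C] = n
    }

  // Instances live in the companion, so they are found without imports.
  implicit val intInt: Aux[Int, Int, Int] = instance[Int, Int, Int]((a, b) => (a, b))
  implicit val intDouble: Aux[Int, Double, Double] = instance[Int, Double, Double]((a, b) => (a.toDouble, b))
  implicit val doubleInt: Aux[Double, Int, Double] = instance[Double, Int, Double]((a, b) => (a, b.toDouble))
  implicit val doubleDouble: Aux[Double, Double, Double] = instance[Double, Double, Double]((a, b) => (a, b))
}

object SumOfProducts {
  // A and B are inferred from the arguments; the static result type is read
  // off the selected instance's Out member.
  def sumOfProducts[A, B](x: Seq[A], y: Seq[B])(implicit p: Promote[A, B]): p.Out = {
    val products = for { a <- x; b <- y } yield {
      val (pa, pb) = p.promote(a, b)
      p.numeric.times(pa, pb)
    }
    products.sum(p.numeric)
  }
}
```

Here sumOfProducts(Seq(1, 2, 3), Seq(4, 5, 6)) is statically an Int, while mixing Int and Double in either order gives a Double. The trade-off is that every supported pair needs its own instance; an unsupported pair such as (Long, Int) fails at compile time instead of falling back to Double.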
It's also worth mentioning that this is directional; you couldn't compute dotProduct(Seq(1.0, 2.0, 3.0), Seq(1, 2, 3)). But that's easy to fix:
implicit def flipNumericPromotion[A, B, C]
(implicit promEv: NumericPromotion[B, A, C])
: NumericPromotion[A, B, C] =
new NumericPromotion[A, B, C] {
override def promote(a: A, b: B): (C, C) = promEv.promote(b, a)
}
Finally, note that your code doesn't compute a dot product. The dot product of [1, 2, 3] and [4, 5, 6] is 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32.
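For comparison, the all-pairs computation in the question factors into the product of the two sums, which is where its result of 90 comes from. Here x_i and y_j denote the elements of the first and second sequence:

```latex
\sum_{i}\sum_{j} x_i y_j = \Big(\sum_i x_i\Big)\Big(\sum_j y_j\Big) = (1+2+3)(4+5+6) = 6 \cdot 15 = 90,
\qquad
\mathbf{x}\cdot\mathbf{y} = \sum_i x_i y_i = 1\cdot 4 + 2\cdot 5 + 3\cdot 6 = 32.
```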
Upvotes: 1