J Calbreath

Reputation: 2705

SumProduct in Spark DataFrame

I want to create essentially a sumproduct across columns in a Spark DataFrame. I have a DataFrame that looks like this:

id    val1   val2   val3   val4
123   10     5      7      5

I also have a Map that looks like:

val coefficients = Map("val1" -> 1, "val2" -> 2, "val3" -> 3, "val4" -> 4)

I want to take the value in each column of the DataFrame, multiply it by the corresponding value from the map, and return the result in a new column so essentially:

(10*1) + (5*2) + (7*3) + (5*4) = 61

I tried this:

val myDF1 = myDF.withColumn("mySum", {var a:Double = 0.0; for ((k,v) <- coefficients) a + (col(k).cast(DoubleType)*coefficients(k));a})

but got an error that the "+" method was overloaded. Even if I solved that, I'm not sure this would work. Any ideas? I could always dynamically build a SQL query as a text string and do it that way, but I was hoping for something a little more elegant.

Any ideas are appreciated.

Upvotes: 0

Views: 1702

Answers (3)

Dylan

Reputation: 13924

It looks like the issue is that you aren't actually doing anything with a

for((k, v) <- coefficients) a + ...

You probably meant a += ...


Also, some advice for cleaning up the block of code inside the withColumn call:

You don't need to call coefficients(k) because you've already got its value in v from for((k,v) <- coefficients)

Scala is pretty good at making one-liners, but it's kinda cheating if you have to put semicolons in that one line :P I'd suggest breaking up the sum calculation section into one line per expression.

The sum expression could be rewritten as a fold which avoids using a var (idiomatic Scala usually avoids vars), e.g.

import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.functions.{col, lit}

coefficients.foldLeft(lit(0.0)){ 
  case (sumSoFar, (k,v)) => col(k).cast(DoubleType) * v + sumSoFar
}
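
For example, this plugs straight into the original withColumn call (a sketch, assuming the myDF from the question):

val myDF1 = myDF.withColumn("mySum",
  coefficients.foldLeft(lit(0.0)) {
    case (sumSoFar, (k, v)) => col(k).cast(DoubleType) * v + sumSoFar
  })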

Upvotes: 2

zero323

Reputation: 330093

The problem with your code is that you try to add a Column to a Double. cast(DoubleType) affects only the type of the stored value, not the type of the column itself. Since Double doesn't provide a +(x: org.apache.spark.sql.Column): org.apache.spark.sql.Column method, everything fails.

To make it work you can for example do something like this:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, lit}

val df = sc.parallelize(Seq(
    (123, 10, 5, 7, 5), (456,  1, 1, 1, 1)
)).toDF("k", "val1", "val2", "val3", "val4")

val coefficients = Map("val1" -> 1, "val2" -> 2, "val3" -> 3, "val4" -> 4)

val dotProduct: Column = coefficients
  // To be explicit you can replace
  // col(k) * v with col(k) * lit(v)
  // but it is not required here
  // since we use the Column.* method, not Int.*
  .map{ case (k, v) => col(k) * v }  // * -> Column.*
  .reduce(_ + _)  // + -> Column.+

df.withColumn("mySum", dotProduct).show
// +---+----+----+----+----+-----+
// |  k|val1|val2|val3|val4|mySum|
// +---+----+----+----+----+-----+
// |123|  10|   5|   7|   5|   61|
// |456|   1|   1|   1|   1|   10|
// +---+----+----+----+----+-----+
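
One design note: reduce throws on an empty collection, so if coefficients could ever be empty you could fold from a zero column instead. A minor variant of the same idea:

val dotProduct: Column = coefficients
  .map { case (k, v) => col(k) * v }
  .foldLeft(lit(0))(_ + _)  // lit(0) is the additive identity, so an empty map yields 0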

Upvotes: 2

Rohan Aletty

Reputation: 2442

I'm not sure if this is possible through the DataFrame API since you are only able to work with columns and not any predefined closures (e.g. your parameter map).

I've outlined a way below using the underlying RDD of the DataFrame:

import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

// Initializing your input example.
val df1 = sc.parallelize(Seq((123, 10, 5, 7, 5))).toDF("id", "val1", "val2", "val3", "val4")

// Return column names as an array
val names = df1.columns

// Grab underlying RDD and zip elements with column names
val rdd1 = df1.rdd.map(row => (0 until row.length).map(row.getInt(_)).zip(names))

// Tack the accumulated total onto the existing row
val rdd2 = rdd1.map { seq =>
  Row.fromSeq(seq.map(_._1) :+
    seq.map { case (value: Int, name: String) => value * coefficients.getOrElse(name, 0) }.sum)
}

// Create output schema (with total)
val totalSchema = StructType(df1.schema.fields :+ StructField("total", IntegerType))

// Apply schema to create output dataframe
val df2 = sqlContext.createDataFrame(rdd2, totalSchema)

// Show output:
df2.show()
...
+---+----+----+----+----+-----+
| id|val1|val2|val3|val4|total|
+---+----+----+----+----+-----+
|123|  10|   5|   7|   5|   61|
+---+----+----+----+----+-----+
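
As an aside on the closure point above: a UDF can in fact close over the coefficients map within the DataFrame API. A minimal sketch, assuming the df1 defined earlier:

import org.apache.spark.sql.functions.{col, udf}

// The UDF closes over the coefficients map from the driver
val weighted = udf((v1: Int, v2: Int, v3: Int, v4: Int) =>
  v1 * coefficients("val1") + v2 * coefficients("val2") +
  v3 * coefficients("val3") + v4 * coefficients("val4"))

df1.withColumn("total", weighted(col("val1"), col("val2"), col("val3"), col("val4")))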

Upvotes: 0
