user124123

Reputation: 1683

Iterate over elements of columns in Scala

I have a dataframe composed of two columns of Array[Double]. I would like to create a new column that is the result of applying a Euclidean distance function to the first two columns, i.e. if I had:

 A      B 
(1,2)  (1,3)
(2,3)  (3,4)

Create:

 A      B     C
(1,2)  (1,3)  1
(2,3)  (3,4)  1.4

My data schema is:

df.schema.foreach(println)
StructField(col1,ArrayType(DoubleType,false),false)
StructField(col2,ArrayType(DoubleType,false),true)
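For reference, a DataFrame like this can be reproduced with a short sketch along these lines (assuming a SparkSession in scope as spark; the nullability flags may differ slightly):

import spark.implicits._

val df = Seq(
  (Array(1.0, 2.0), Array(1.0, 3.0)),
  (Array(2.0, 3.0), Array(3.0, 4.0))
).toDF("col1", "col2")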

Whenever I call this distance function:

import scala.math.{pow, sqrt}

def distance(xs: Array[Double], ys: Array[Double]) = {
  sqrt((xs zip ys).map { case (x, y) => pow(y - x, 2) }.sum)
}

I get a type error:

df.withColumn("distances" , distance($"col1",$"col2"))
<console>:68: error: type mismatch;
 found   : org.apache.spark.sql.ColumnName
 required: Array[Double]
       ids_with_predictions_centroids3.withColumn("distances" , distance($"col1",$"col2"))

I understand I have to iterate over the elements of each column, but I cannot find an explanation of how to do this anywhere. I am very new to Scala programming.

Upvotes: 4

Views: 2457

Answers (2)

Ramesh Maharjan

Reputation: 41987

Spark functions work on columns, and your only mistake is that you are mixing columns and primitives in the function.

The error message is clear enough: you are passing columns to the distance function, i.e. $"col1" and $"col2" are columns, but the function is defined as distance(xs: Array[Double], ys: Array[Double]), which takes primitive types.

The solution is to make the distance function fully column-based:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._

// Euclidean distance for two-element array columns, built from Spark's column functions
def distance(xs: Column, ys: Column) = {
  sqrt(pow(ys(0) - xs(0), 2) + pow(ys(1) - xs(1), 2))
}

df.withColumn("distances" , distance($"col1",$"col2")).show(false)

which should give you the correct result without errors

+------+------+------------------+
|col1  |col2  |distances         |
+------+------+------------------+
|[1, 2]|[1, 3]|1.0               |
|[2, 3]|[3, 4]|1.4142135623730951|
+------+------+------------------+
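Note that this column-based version hard-codes arrays of exactly two elements. For arrays of arbitrary length, a rough sketch using Spark's built-in higher-order functions could look like this (assuming Spark 2.4+, which provides zip_with and aggregate in SQL expressions):

import org.apache.spark.sql.functions.expr

// square the element-wise differences, sum them, then take the square root
df.withColumn("distances",
  expr("sqrt(aggregate(zip_with(col1, col2, (x, y) -> pow(y - x, 2)), 0D, (acc, v) -> acc + v))")
).show(false)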

I hope the answer is helpful.

Upvotes: 3

Shaido

Reputation: 28392

To use a custom function on a DataFrame you need to define it as a UDF. This can be done, for example, as follows:

import org.apache.spark.sql.functions.udf
import scala.collection.mutable.WrappedArray

val distance = udf((xs: WrappedArray[Double], ys: WrappedArray[Double]) => {
  math.sqrt((xs zip ys).map { case (x, y) => math.pow(y - x, 2) }.sum)
})

df.withColumn("C", distance($"A", $"B")).show()

Note that WrappedArray (or Seq) needs to be used here, since Spark hands array columns to a UDF as WrappedArray rather than Array.

Resulting dataframe:

+----------+----------+------------------+
|         A|         B|                 C|
+----------+----------+------------------+
|[1.0, 2.0]|[1.0, 3.0]|               1.0|
|[2.0, 3.0]|[3.0, 4.0]|1.4142135623730951|
+----------+----------+------------------+
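As a variation, the same UDF can also be written against Seq[Double], which avoids the WrappedArray import; the sketch below (with the purely illustrative name distanceSeq) is equivalent to the version above:

val distanceSeq = udf((xs: Seq[Double], ys: Seq[Double]) =>
  math.sqrt((xs zip ys).map { case (x, y) => math.pow(y - x, 2) }.sum)
)

df.withColumn("C", distanceSeq($"A", $"B")).show()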

Upvotes: 4
