I have a DataFrame with this structure:
|-- col0: double (nullable = true)
|-- arr: array (nullable = true)
| |-- element: array (containsNull = true)
| | |-- element: double (containsNull = false)
The array column should hold two elements (each one itself an array), but one of them can be missing. As an example, I have this:
|0.0 |[[0.0, 182.0], [1.0, 14.0]]|
|0.0 |[[1.0, 60.0]] |
|1.0 |[[0.0, 3.0], [1.0, 48.0]] |
|2.0 |[[1.0, 6.0], [0.0, 111.0]] |
|0.0 |[[1.0, 4.0], [0.0, 120.0]] |
|2.0 |[[0.0, 21.0]] |
|0.0 |[[0.0, 3.0], [1.0, 13.0]] |
And the desired result is:
|0.0 |[[0.0, 182.0], [1.0, 14.0]]|
|0.0 |[[0.0, 0.0], [1.0, 60.0]] |
|1.0 |[[0.0, 3.0], [1.0, 48.0]] |
|2.0 |[[0.0, 111.0], [1.0, 6.0]] |
|0.0 |[[0.0, 120.0], [1.0, 4.0]] |
|2.0 |[[0.0, 21.0], [1.0, 0.0]] |
|0.0 |[[0.0, 3.0], [1.0, 13.0]] |
So, when the array has 2 elements, there is nothing to do. But if it has only one element, I need to create a second element for the value that is missing: if the existing element starts with 0.0, I need to add [1.0, 0.0], and if it starts with 1.0, I need to add [0.0, 0.0].
I have tried the following, but it didn't work:
val headValue = udf((arr: Array[Array[Double]], maxValue: Double, minValue: Double) => {
  val flatArr = arr.flatMap(_.headOption)
  val nArr = arr
  if (flatArr.length == 1) {
    if (flatArr.head == maxValue) {
      nArr :+ Array(minValue, 0.0)
    } else {
      nArr :+ Array(maxValue, 0.0)
    }
  } else {
    nArr
  }
})
df.withColumn("Test", headValue(df("arrOfarr"), lit(maxValue), lit(minValue) ))
And the error is:
org.apache.spark.SparkException: Failed to execute user defined function(anonfun$20: (array<array<double>>, double, double) => array<array<double>>)
...
Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [[D
Upvotes: 0
Views: 251
Reputation: 211
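Instead of declaring the UDF's input as Array, declare it as Seq and you should be good. As the stack trace shows, Spark hands the array column to the UDF as a WrappedArray (which is a Seq), and that cannot be cast to Array[Array[Double]]: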
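import org.apache.spark.sql.functions.udf

val headValue = udf((arr: Seq[Seq[Double]], maxValue: Double, minValue: Double) => {
  // First value of each inner array, i.e. the keys that are already present
  val flatArr = arr.flatMap(_.headOption)
  val nArr = arr
  if (flatArr.length == 1) {
    // Only one key present: append the other key with a count of 0.0
    if (flatArr.head == maxValue) {
      nArr :+ Seq(minValue, 0.0)
    } else {
      nArr :+ Seq(maxValue, 0.0)
    }
  } else {
    nArr
  }
})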
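Note that the UDF above appends the missing pair at the end and leaves two-element arrays as they are, while the desired output in the question also appears to be ordered by the first value of each pair. If that ordering matters, a minimal sketch of the core logic as a plain Scala function could look like the following. The names `FillMissing` and `fillAndSort` are hypothetical, and the sketch assumes the array itself is never null and that only the two keys `minValue` and `maxValue` can occur:

```scala
object FillMissing {
  // Ensures that both keys (minValue and maxValue) are present, adding a
  // [key, 0.0] pair for any key that is missing, then orders the pairs by key.
  // Assumption: arr is non-null; empty inner arrays are sorted last.
  def fillAndSort(arr: Seq[Seq[Double]], maxValue: Double, minValue: Double): Seq[Seq[Double]] = {
    // Keys that already appear as the first value of a pair
    val present = arr.flatMap(_.headOption).toSet
    // One zero-count pair per expected key that is absent
    val missing = Seq(minValue, maxValue).distinct
      .filterNot(present.contains)
      .map(k => Seq(k, 0.0))
    // Sort key: first value of the pair, or a large sentinel for empty pairs
    def key(pair: Seq[Double]): Double = pair.headOption.getOrElse(Double.MaxValue)
    (arr ++ missing).sortWith((a, b) => key(a) < key(b))
  }
}
```

In Spark this could presumably be wrapped with `udf(FillMissing.fillAndSort _)` and applied with `withColumn` in the same way as `headValue` in the question. Keeping the logic in a plain function makes it possible to check it without a Spark session.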
Upvotes: 1