Leothorn

Reputation: 1345

How to create a new column for dataset using ".withColumn" with many conditions in Scala Spark

I have the following input array

val bins = Array(("bin1",1.0,2.0),("bin2",3.0,4.0),("bin3",5.0,6.0))

Basically, the strings like "bin1" refer to values in a reference column on which the dataframe is filtered; a new column is then created from another column based on the boundary conditions given by the remaining two doubles in each tuple.

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

var number_of_dataframes = bins.length
var ctempdf = spark.createDataFrame(sc.emptyRDD[Row], train_data.schema)
ctempdf = ctempdf.withColumn(colName, col(colName))
val t1 = System.nanoTime
for (x <- 0 to bins.length - 1) {
      var tempdf = train_data.filter(col(refCol) === bins(x)._1)
      // println(bins(x)._1)
      tempdf = tempdf.withColumn(colName,
                                 when(col(colName) < bins(x)._2, bins(x)._2)
                                 .when(col(colName) > bins(x)._3, bins(x)._3)
                                 .otherwise(col(colName)))
      ctempdf = ctempdf.union(tempdf)
      val duration = (System.nanoTime - t1) / 1e9d
      println(duration)
}

The code above gets incrementally slower as the number of bins increases. Is there a way I can speed this up drastically? This code is itself nested in another loop.

I have tried checkpoint / persist / cache, and none of them help.

Upvotes: 0

Views: 562

Answers (1)

10465355

Reputation: 4631

There is no need for iterative union here. Create a literal map<string, struct<double, double>> using o.a.s.sql.functions.map (in functional terms it behaves like a delayed string => struct<lower: double, upper: double>):

import org.apache.spark.sql.functions._

val bins: Seq[(String, Double, Double)] = Seq(
  ("bin1",1.0,2.0),("bin2",3.0,4.0),("bin3",5.0,6.0))

val binCol = map(bins.map { 
  case (key, lower, upper) => Seq(
    lit(key), 
    struct(lit(lower) as "lower", lit(upper) as "upper")) 
}.flatten: _*)

Define expressions like these (they are simple lookups into the predefined mapping: binCol(col(refCol)) is a delayed struct<lower: double, upper: double>, and the subsequent apply takes the lower or upper field):

val lower = binCol(col(refCol))("lower")
val upper = binCol(col(refCol))("upper")
val c = col(colName)

and use CASE ... WHEN ... (Spark Equivalent of IF Then ELSE)

val result = when(c.between(lower, upper), c)
  .when(c < lower, lower)
  .when(c > upper, upper)

select and drop NULLs:

df
  .withColumn(colName, result)
  // If value is still NULL it means we didn't find refCol key in binCol keys.
  // To mimic .filter(col(refCol) === ...) we drop the rows
  .na.drop(Seq(colName))

This solution assumes that there are no NULL values in colName at the beginning, but it can easily be adjusted to handle cases where that assumption does not hold.
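For example, one possible adjustment (a sketch only, not part of the original answer) is to detect unmatched refCol keys through the map lookup itself rather than through the result column, so rows whose colName was NULL from the start are kept instead of being dropped:

// Sketch only: test the map lookup directly to find unmatched keys,
// so pre-existing NULLs in colName survive the filtering step.
val keyFound = binCol(col(refCol)).isNotNull

df
  .filter(keyFound)               // mimics .filter(col(refCol) === ...)
  .withColumn(colName, result)    // a NULL input simply stays NULL here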

If the process is still unclear I'd recommend tracing it step-by-step with literals:

spark.range(1).select(binCol as "map").show(false)
+------------------------------------------------------------+
|map                                                         |
+------------------------------------------------------------+
|[bin1 -> [1.0, 2.0], bin2 -> [3.0, 4.0], bin3 -> [5.0, 6.0]]|
+------------------------------------------------------------+
spark.range(1).select(binCol(lit("bin1")) as "value").show(false)
+----------+
|value     |
+----------+
|[1.0, 2.0]|
+----------+
spark.range(1).select(binCol(lit("bin1"))("lower") as "value").show
+-----+
|value|
+-----+
|  1.0|
+-----+

and further referring to Querying Spark SQL DataFrame with complex types.
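Putting the pieces together on a toy DataFrame (a sketch only; it assumes the expressions above were built with refCol = "bin" and colName = "value", and the sample rows are made up, not taken from the question):

import spark.implicits._

// Toy input: the "bin" / "value" column names are assumptions
val toyDF = Seq(
  ("bin1", 0.5),  // below lower bound 1.0 -> clamped to 1.0
  ("bin2", 3.5),  // inside [3.0, 4.0]     -> unchanged
  ("bin3", 9.0),  // above upper bound 6.0 -> clamped to 6.0
  ("bin4", 2.0)   // key not in the map    -> row dropped
).toDF("bin", "value")

toyDF.withColumn("value", result).na.drop(Seq("value")).show()
// expected rows (order may vary): (bin1, 1.0), (bin2, 3.5), (bin3, 6.0); bin4 is dropped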

Upvotes: 2
