Reputation: 1345
I have the following input array
val bins = Array(("bin1",1.0,2.0),("bin2",3.0,4.0),("bin3",5.0,6.0))
Basically, each string such as "bin1" refers to a value in a reference column on which the dataframe is filtered; a new column is then created from another column based on the boundary conditions given by the remaining two doubles in each tuple.
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

var number_of_dataframes = bins.length
var ctempdf = spark.createDataFrame(sc.emptyRDD[Row], train_data.schema)
ctempdf = ctempdf.withColumn(colName, col(colName))
val t1 = System.nanoTime
for (x <- 0 to bins.length - 1)
{
  var tempdf = train_data.filter(col(refCol) === bins(x)._1)
  //println(bins(x)._1)
  tempdf = tempdf.withColumn(colName,
    when(col(colName) < bins(x)._2, bins(x)._2)
      .when(col(colName) > bins(x)._3, bins(x)._3)
      .otherwise(col(colName)))
  ctempdf = ctempdf.union(tempdf)
  val duration = (System.nanoTime - t1) / 1e9d
  println(duration)
}
The code above gets incrementally slower as the number of bins increases. Is there a way I can speed this up drastically? This code is itself nested inside another loop.
I have tried checkpoint / persist / cache and they do not help.
Upvotes: 0
Views: 562
Reputation: 4631
There is no need for iterative union here. Create a literal map<string, struct<double, double>> using o.a.s.sql.functions.map (in functional terms it behaves like a delayed string => struct<lower: double, upper: double>):
import org.apache.spark.sql.functions._
val bins: Seq[(String, Double, Double)] = Seq(
("bin1",1.0,2.0),("bin2",3.0,4.0),("bin3",5.0,6.0))
val binCol = map(bins.map {
case (key, lower, upper) => Seq(
lit(key),
struct(lit(lower) as "lower", lit(upper) as "upper"))
}.flatten: _*)
Define expressions like these (these are simple lookups in the predefined mapping, so binCol(col(refCol)) is a delayed struct<lower: double, upper: double> and the remaining apply takes the lower or upper field):
val lower = binCol(col(refCol))("lower")
val upper = binCol(col(refCol))("upper")
val c = col(colName)
and use CASE ... WHEN ... (Spark Equivalent of IF Then ELSE):
val result = when(c.between(lower, upper), c)
.when(c < lower, lower)
.when(c > upper, upper)
Select and drop NULLs:
df
.withColumn(colName, result)
// If value is still NULL it means we didn't find refCol key in binCol keys.
// To mimic .filter(col(refCol) === ...) we drop the rows
.na.drop(Seq(colName))
This solution assumes that there are no NULL values in colName at the beginning, but it can be easily adjusted to handle cases where this assumption is not satisfied.
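For example, one possible adjustment (a minimal sketch; knownKeys is just an illustrative name, and it reuses bins, df and result from above) is to filter on key membership up front instead of relying on the NULLs produced by the map lookup:
// Keep only rows whose refCol value has an entry in bins, so that any
// pre-existing NULLs in colName survive instead of being dropped.
val knownKeys = bins.map(_._1)

df
  .filter(col(refCol).isin(knownKeys: _*)) // mimics .filter(col(refCol) === ...)
  .withColumn(colName, result)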
If the process is still unclear I'd recommend tracing it step-by-step with literals:
spark.range(1).select(binCol as "map").show(false)
+------------------------------------------------------------+
|map |
+------------------------------------------------------------+
|[bin1 -> [1.0, 2.0], bin2 -> [3.0, 4.0], bin3 -> [5.0, 6.0]]|
+------------------------------------------------------------+
spark.range(1).select(binCol(lit("bin1")) as "value").show(false)
+----------+
|value |
+----------+
|[1.0, 2.0]|
+----------+
spark.range(1).select(binCol(lit("bin1"))("lower") as "value").show
+-----+
|value|
+-----+
| 1.0|
+-----+
and further referring to Querying Spark SQL DataFrame with complex types.
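Putting it all together, here is a rough end-to-end sketch on a toy DataFrame (it reuses binCol from above; the column names ref and value and the sample rows are made up for illustration):
import org.apache.spark.sql.functions._
import spark.implicits._

val refCol = "ref"
val colName = "value"

// Toy input: one row per known bin plus one unknown key.
val df = Seq(
  ("bin1", 0.5),  // below the [1.0, 2.0] range -> clamped to 1.0
  ("bin2", 3.5),  // inside the [3.0, 4.0] range -> unchanged
  ("bin3", 9.0),  // above the [5.0, 6.0] range -> clamped to 6.0
  ("bin4", 1.0)   // no matching key in bins    -> row is dropped
).toDF(refCol, colName)

val lower = binCol(col(refCol))("lower")
val upper = binCol(col(refCol))("upper")
val c = col(colName)

val result = when(c.between(lower, upper), c)
  .when(c < lower, lower)
  .when(c > upper, upper)

df.withColumn(colName, result).na.drop(Seq(colName)).show()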
Upvotes: 2