Dasarathy D R

Reputation: 345

How to avoid using withColumn iteratively in Spark Scala?

Currently, we have code that uses withColumn iteratively: a when/otherwise conditional check, with arithmetic computations on top of that.

Sample Code:

df.withColumn("col4", when(col("col1") > 10, col("col2") + col("col3")).otherwise(col("col2")))

Similar arithmetic computations happen iteratively for another 40 columns. Total record count: 2M.

Code that is reiterated:

  df.withColumn(colName1, when(col(amountToSubtract) < "0" && col(colName1) === "0" && col(amountToSubtract) <= col(colName2) && col(colName2) =!= "0", col(colName2)).otherwise(col(colName1)))
    .withColumn(amountToSubtract, when(col(amountToSubtract) =!= "0", col(amountToSubtract) - col(colName2)).otherwise(col(amountToSubtract)))
    .withColumn(colName1, when(col(amountToSubtract) > "0" && col(colName1) === "0", col(colName2) + col(amountToSubtract)).otherwise(col(colName1)))
    .withColumn(amountToSubtract, when(col(amountToSubtract) > "0" && col(colName1) === "0", "0").otherwise(col(amountToSubtract)))

This pattern is repeated for 7 or 8 other sets of computations.

The job hangs at this point for a long time, and sometimes it throws a GC overhead error. I am aware that iterative usage of .withColumn hurts performance, but I could not find an alternative way to implement the conditional checks above. Kindly assist.

Upvotes: 2

Views: 1405

Answers (2)

Chris Bedford

Reputation: 2692

You could try encapsulating all the logic you currently implement with when/otherwise into a single UDF that takes as input an array of all the column values you need to consider, and returns as output an array of all the generated column values. UDFs sometimes have their own performance issues (they are opaque to the Catalyst optimizer), but it might be worth a shot. Here's a simple illustration of the technique I am thinking of:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, udf}

object SO extends App {

  val sparkSession = SparkSession.builder().appName("simple").master("local[*]").getOrCreate()
  sparkSession.sparkContext.setLogLevel("ERROR")

  import sparkSession.implicits._

  case class Record(col1: Int, col2: Int, amtToSubtract: Int)

  val recs = Seq(
    Record(1, 2, 3),
    Record(11, 2, 3)
  ).toDS()


  val colGenerator : Seq[Int] => Seq[Int] =
    (arr: Seq[Int]) =>  {
      val (in_c1, in_c2, in_amt_sub) = (arr(0), arr(1), arr(2))

      val newColName1_a = if (in_amt_sub < 0 && in_c1 == 0 && in_amt_sub < in_c2 &&  in_c2 != 0) {
        in_c2
      }
      else {
        in_c1
      }
      val newAmtSub_a = if (in_amt_sub != 0) {
        in_amt_sub - in_c2
      } else {
        in_amt_sub
      }

      val newColName1_b  = if (  newAmtSub_a  > 0 &&  newColName1_a  == 0 ) {
        in_c2  + newAmtSub_a
      } else {
        newColName1_a
      }

      val newAmtSub_b = if (newAmtSub_a  > 0  && newColName1_b   == 0) {
        0
      } else {
        newAmtSub_a
      }

      Seq(newColName1_b,  newAmtSub_b)
    }

  val colGeneratorUdf = udf(colGenerator)

  // Here the first column in the generated array is 'col4', the UDF could equivalently generate as many
  // other values as you want from the input array of column values.
  //
  val afterUdf = recs.withColumn("colsInStruct",   colGeneratorUdf (array($"col1", $"col2", $"amtToSubtract")))
  afterUdf.show()
  // RESULT
  //+----+----+-------------+------------+
  //|col1|col2|amtToSubtract|colsInStruct|
  //+----+----+-------------+------------+
  //|   1|   2|            3|      [1, 1]|
  //|  11|   2|            3|     [11, 1]|
  //+----+----+-------------+------------+


}

Upvotes: 1

firsni

Reputation: 916

I'm not sure of the logic you're trying to apply, but here is the idea of the foldLeft:

  val colList = List("col1", "col2", "col3")
  val df: DataFrame = ???
  val amountToSubtract: String = ???  // name of your "amount to subtract" column
  val colName2: String = ???          // name of the second column in each pair

  colList.foldLeft(df) { case (df, colName1) =>
    df.withColumn(colName1, when(col(amountToSubtract) < "0" && col(colName1) === "0" && col(amountToSubtract) <= col(colName2) && col(colName2) =!= "0", col(colName2)).otherwise(col(colName1)))
      .withColumn(amountToSubtract, when(col(amountToSubtract) =!= "0", col(amountToSubtract) - col(colName2)).otherwise(col(amountToSubtract)))
      .withColumn(colName1, when(col(amountToSubtract) > "0" && col(colName1) === "0", col(colName2) + col(amountToSubtract)).otherwise(col(colName1)))
      .withColumn(amountToSubtract, when(col(amountToSubtract) > "0" && col(colName1) === "0", "0").otherwise(col(amountToSubtract)))
  }
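If foldLeft itself is unfamiliar: it threads an accumulator through the list, applying the function once per element; in the Spark snippet the accumulator is the DataFrame. A minimal self-contained sketch of the same shape, using a plain Int accumulator so it runs without Spark:

```scala
// foldLeft(initial) { case (acc, elem) => nextAcc } threads an accumulator
// through a collection. In the Spark version above, acc is the DataFrame and
// each step appends a few withColumn transformations to the plan.
object FoldLeftSketch extends App {
  val colList = List("col1", "col2", "col3")
  val totalLength = colList.foldLeft(0) { case (acc, colName) =>
    acc + colName.length // stands in for df.withColumn(...)
  }
  println(totalLength) // prints 12
}
```

Because each step returns a new accumulator, the whole chain builds one value (or one query plan) without any mutable state.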

Upvotes: 1
