Wolf Rendall

Reputation: 417

Replace Values in Spark Dataframe based on Map

I have a dataset of integers, some of which are real data and some of which, at or above a certain threshold, are error codes. I also have a Map from column names to the start of each column's error code range. I would like to use this map to conditionally replace values: for example, setting a value to None when it is at or above the start of that column's error range.

val errors = Map("Col_1" -> 100, "Col_2" -> 10)

val df = Seq(("john", 1, 100), ("jacob", 10, 100), ("heimer", 1000, 
1)).toDF("name", "Col_1", "Col_2")

df.take(3)
// name   | Col_1 | Col_2
// john   | 1     | 1
// jacob  | 10    | 10
// heimer | 1000  | 1

// create some function like this
def fixer = udf((column_value: Int, column_name: String) => {
    val crit_val = errors(column_name)
    if (column_value >= crit_val) {
        None
    } else {
        Some(column_value)
    }
})

//apply it in some way
val fixed_df = df.columns.map(_ -> fixer(_))

//to get output like this:
fixed_df.take(3)
// name   | Col_1 | Col_2
// john   | 1     | 1
// jacob  | 10    | None
// heimer | None  | 1

Upvotes: 0

Views: 2241

Answers (1)

Tzach Zohar

Reputation: 37822

It's not very convenient to do this with a UDF: a UDF takes one or more specific columns and returns a single column, whereas here you want to handle several different columns. Moreover, checking a threshold and replacing the value with a constant can be done with Spark's built-in when function, so no UDF is needed.
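
For reference, here is what when/otherwise looks like on a single column; this is just a minimal sketch using the Col_1 threshold from the question:

import org.apache.spark.sql.functions.{col, when}

// values at or above the Col_1 error threshold (100) are replaced with null
df.withColumn("Col_1", when(col("Col_1") >= 100, null).otherwise(col("Col_1")))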

So here's a way to apply when to each column that has a threshold, folding over the relevant columns to produce the desired DataFrame (we'll replace "bad" values with null):

import org.apache.spark.sql.functions._
import spark.implicits._

// fold the list of errors, replacing the original column
// with a "corrected" column with same name in each iteration
val newDf = errors.foldLeft(df) { case (tmpDF, (colName, threshold)) =>
  tmpDF.withColumn(colName, when($"$colName" >= threshold, null).otherwise($"$colName"))
}

newDf.show()
// +------+-----+-----+
// |  name|Col_1|Col_2|
// +------+-----+-----+
// |  john|    1|    1|
// | jacob|   10| null|
// |heimer| null|    1|
// +------+-----+-----+
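
If you prefer a single pass, the same replacement can be written as one select that builds an expression per column. This is only an alternative sketch of the same when-based approach, assuming the same df, errors, and imports as above; which version reads better is mostly a matter of taste:

val newDf2 = df.select(df.columns.map { colName =>
  errors.get(colName) match {
    // column has a threshold: null out values at or above it, keep the column name
    case Some(threshold) => when(col(colName) >= threshold, null).otherwise(col(colName)).as(colName)
    // no threshold for this column: pass it through unchanged
    case None => col(colName)
  }
}: _*)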

Upvotes: 3
