Brooklyn Knightley

Reputation: 71

Elegant way to apply UDF by condition

I have some input files that all share the same schema. Each has a field named channel_id, but in file1 channel_id = 1 and in file2 channel_id = 2.

I need to do some ETL on these files, but the logic differs from file to file. For example, there is a UDF to calculate channel_name:

val getChannelNameUdf : UserDefinedFunction = udf((channelId: Integer) => {
    if (channelId == 1) {
      "English"
    } else if (channelId == 2) {
      "French"
    } else {
      ""
    }
  })

As we have several channels, an if-else chain does not seem elegant. Are there more elegant ways or suitable design patterns to write the code? Thanks a lot.

Upvotes: 1

Views: 828

Answers (2)

blackbishop

Reputation: 32650

Are there more elegant ways or suitable design patterns to write the code?

Yes! A simple and efficient way of doing this is to use a join.

Keep a reference file that lists every channel, say with the columns channel_id and channel_name, and join it with your DataFrame. Something like this:

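// Assumes the CSV has a header row (channel_id,channel_name); without it the columns would be named _c0, _c1.
val df_channels = spark.read.option("header", "true").csv("/path/to/channel_file.csv")

val result = df.join(df_channels, Seq("channel_id"), "left")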

Upvotes: 1

baitmbarek

Reputation: 2518

Hi Brooklyn and welcome to StackOverflow,

You can use pattern matching in your UDF, but I'd suggest using the built-in when function instead of defining your own UDF.

To answer your request, here's the code you may need:

val getChannelNameUdf = udf[String, Int] { _ match {
  case 1 => "English"
  case 2 => "French"
  case _ => ""
}}

or, even better, just a pattern-matching anonymous function:

val getChannelNameUdf = udf[String, Int] {
  case 1 => "English"
  case 2 => "French"
  case _ => ""
}

Here's an example using the built-in when function:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.when

val getChannelName = {col: Column =>
  when(col === 1, "English").when(col === 2, "French").otherwise("")
}
df.withColumn("channelName", getChannelName($"channelId"))

Edit: For a more generic approach you can use the following definitions:

val rules = Map(1 -> "English", 2 -> "French")
val getChannelName = {col: Column =>
  rules.foldLeft(lit("")){case (c, (i,label)) =>
    when(col === i, label).otherwise(c)
  }
}

and then

df.withColumn("channelName", getChannelName($"channelId"))
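If a UDF is still needed, the same rules map can replace the if-else chain with a plain lookup. Here is a minimal sketch in plain Scala, so it runs without Spark; the object name ChannelLookup and the helper channelNameOf are hypothetical names chosen for illustration, and the empty-string fallback mirrors the versions above.

```scala
object ChannelLookup {
  // Lookup table copied from the rules map above.
  val rules: Map[Int, String] = Map(1 -> "English", 2 -> "French")

  // Hypothetical helper: returns the channel name for a known id,
  // or "" for an unknown id (same fallback as the if-else version).
  def channelNameOf(channelId: Int): String =
    rules.getOrElse(channelId, "")

  def main(args: Array[String]): Unit = {
    println(channelNameOf(1))         // known id
    println(channelNameOf(2))         // known id
    println(channelNameOf(3).isEmpty) // unknown id falls back to ""
  }
}
```

With Spark on the classpath, something like udf(ChannelLookup.channelNameOf _) should wrap this helper as a UDF, although the when-based version is usually preferable because Spark can optimize built-in functions.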

Upvotes: 1
