archerarjun

Reputation: 23

Creating UDF compatible with DataFrame and SQL API

I am trying to write a UDF that works both with DataFrames and with Spark SQL.

Here is the code:

def Timeformat(timecol1: Int) = {
  if (timecol1 >= 1440)
    "%02d:%02d".format((timecol1 - 1440) / 60, (timecol1 - 1440) % 60)
  else
    "%02d:%02d".format(timecol1 / 60, timecol1 % 60)
}

sqlContext.udf.register("Timeformat", Timeformat _)

This method works perfectly with sqlContext:

val dataa = sqlContext.sql("""select Timeformat(abc.time_band) from abc""")

Using the DataFrame API, however, the same function throws a type mismatch error:

val fcstdataa = abc.select(Timeformat(abc("time_band_start")))

<console>:41: error: type mismatch;
 found   : org.apache.spark.sql.Column
 required: Int

When I rewrite the UDF as below, it works perfectly with the DataFrame API but does not work with sqlContext. Is there any way to solve this without creating multiple UDFs that do the same thing?

val Timeformat = udf((timecol1: Int) =>
  if (timecol1 >= 1440)
    "%02d:%02d".format((timecol1 - 1440) / 60, (timecol1 - 1440) % 60)
  else
    "%02d:%02d".format(timecol1 / 60, timecol1 % 60)
)

I am pretty new to Scala and Spark. What is the difference between the two declarations, and is one approach better than the other?

Upvotes: 1

Views: 586

Answers (1)

zero323

Reputation: 330433

It doesn't really make sense to use a UDF here, but if you really want one, simply don't use an anonymous function. Take the function you already have (the Int => String one) and wrap it with udf:

import org.apache.spark.sql.functions.udf

def Timeformat(timecol1: Int): String = ???

sqlContext.udf.register("Timeformat", Timeformat _)  // callable by name from SQL
val timeformat_ = udf(Timeformat _)                   // usable with the DataFrame API
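A minimal usage sketch, assuming the ??? is filled in with the body from the question and that abc is the question's DataFrame, registered as a temp table named abc:

// SQL API: resolved through the name registered above
sqlContext.sql("SELECT Timeformat(time_band) FROM abc")

// DataFrame API: the udf-wrapped version takes Columns
abc.select(timeformat_(abc("time_band_start")))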

Alternatively you can use callUDF:

import org.apache.spark.sql.functions.callUDF

abc.select(callUDF("Timeformat", $"time_band_start"))

That being said, a non-UDF solution should be preferred most of the time:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{when, format_string}

def timeformatExpr(col: Column) = {
  // Subtract 1440 when the value is at least 1440, then render as hours:minutes
  val offset = when(col >= 1440, 1440).otherwise(0)
  val x = ((col - offset) / 60).cast("int")
  val y = (col - offset) % 60
  format_string("%02d:%02d", x, y)
}

which is equivalent to the following SQL:

val expr = """CASE
  WHEN time_band >= 1440 THEN
      FORMAT_STRING(
          '%02d:%02d', 
          CAST((time_band - 1440) / 60 AS INT),
          (time_band - 1440) % 60
      )
  ELSE 
      FORMAT_STRING(
          '%02d:%02d', 
          CAST(time_band / 60 AS INT),
          time_band % 60
      )
END"""

which can be used in raw SQL as well as on a DataFrame with selectExpr or the expr function.

Examples:

val df = Seq((1L, 123123), (2L, 42132), (3L, 99)).toDF("id", "time_band")

df.select($"*", timeformatExpr($"time_band").alias("tf")).show
// +---+---------+-------+
// | id|time_band|     tf|
// +---+---------+-------+
// |  1|   123123|2028:03|
// |  2|    42132| 678:12|
// |  3|       99|  01:39|
// +---+---------+-------+

df.registerTempTable("df")

sqlContext.sql(s"SELECT *, $expr AS tf FROM df").show
// +---+---------+-------+
// | id|time_band|     tf|
// +---+---------+-------+
// |  1|   123123|2028:03|
// |  2|    42132| 678:12|
// |  3|       99|  01:39|
// +---+---------+-------+

df.selectExpr("*", s"$expr AS tf").show
// +---+---------+-------+
// | id|time_band|     tf|
// +---+---------+-------+
// |  1|   123123|2028:03|
// |  2|    42132| 678:12|
// |  3|       99|  01:39|
// +---+---------+-------+
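For completeness, a minimal sketch of the expr variant mentioned above; the functions.expr import is aliased here only to avoid shadowing the expr string defined earlier:

import org.apache.spark.sql.functions.{expr => sqlExpr}

df.select($"*", sqlExpr(expr).alias("tf")).show
// Same output as the selectExpr example above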

Upvotes: 1
