IshitaV
IshitaV

Reputation: 73

Spark SQL(v2.0) UDAF in Scala returns empty string

While I am trying to create a UDAF for a complex problem of ours, I decided to start with a basic UDAF which returns the column as it is. Since I am new to Spark SQL/ Scala, can somebody please help me and highlight my mistake.

Following is the code:

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row 
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes
 
import scala.collection._
 
object MinhashUdaf extends UserDefinedAggregateFunction {
 
  override def inputSchema: StructType = StructType(
    StructField("value", StringType) :: Nil
  )

  override def bufferSchema: StructType = StructType(
    StructField("shingles", (StringType)) :: Nil
  )

  override def dataType: DataType = (StringType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ("")   
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, input.toString())   
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}

  override def evaluate(buffer: Row): Any = {
    buffer(0)   
  } 
}

For running the above UDAF, following is the code:

def main(args: Array[String]) {
  val spark: SparkSession = SparkSession.builder
    .master("local[*]")
    .appName("test")
    .getOrCreate();

  import spark.implicits._
 
  val df = spark.read.json("people.json")
  df.createOrReplaceTempView("people")
  val sqlDF = spark.sql("Select name from people")
  sqlDF.show()

  val minhash = df.select(MinhashUdaf(col("name")).as("minhash"))
  minhash.printSchema()
  minhash.show(truncate = false)
}

Since in UDAF I am returning the input as it is, I should get the value of column "name" for each row as it is. Whereas on running the above string, I am returned with an empty string.

Upvotes: 2

Views: 681

Answers (1)

lucas kim
lucas kim

Reputation: 910

You did not implement the merge function.

Using the code below, you can print the value of the column as you want.

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row 
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes

object MinhashUdaf extends UserDefinedAggregateFunction {
    
  override def inputSchema: StructType = StructType(
    StructField("value", StringType) :: Nil
  )
    
  override def bufferSchema: StructType = StructType(
    StructField("shingles", (StringType)) :: Nil
  )
    
  override def dataType: DataType = (StringType)
    
  override def deterministic: Boolean = true
    
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ("") 
  }
    
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { 
    buffer.update(0, input.get(0)) 
  }
    
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {  
     buffer1.update(0, buffer2.get(0))
  }
    
  override def evaluate(buffer: Row): Any = { 
     buffer(0) 
  } 
}

Upvotes: 3

Related Questions