Reputation: 73
While I am trying to create a UDAF for a complex problem of ours, I decided to start with a basic UDAF that returns the column as it is. Since I am new to Spark SQL and Scala, can somebody please help me and point out my mistake?
Following is the code:
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes
import scala.collection._
// Experimental UDAF over a single string column. It keeps one string in its
// aggregation buffer and returns that buffer slot as the result.
object MinhashUdaf extends UserDefinedAggregateFunction {
// Input: one StringType field named "value".
override def inputSchema: StructType = StructType(
StructField("value", StringType) :: Nil
)
// Buffer: one StringType field named "shingles".
override def bufferSchema: StructType = StructType(
StructField("shingles", (StringType)) :: Nil
)
// The aggregate's result type is a string.
override def dataType: DataType = (StringType)
override def deterministic: Boolean = true
// Every buffer starts out holding the empty string in slot 0.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = ("")
}
// Overwrites slot 0 with the string rendering of the whole input Row.
// NOTE(review): this is input.toString(), not the value of the row's first
// field (e.g. input.get(0)) -- confirm which one is intended.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0, input.toString())
}
// Empty body: buffer1 is left untouched and buffer2 is ignored.
// NOTE(review): if Spark combines partial buffers through merge (see the
// UserDefinedAggregateFunction docs), whatever update wrote is never carried
// into the final buffer, which would leave the "" set by initialize.
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}
// The result is whatever currently sits in slot 0 of the buffer.
override def evaluate(buffer: Row): Any = {
buffer(0)
}
}
For running the above UDAF, following is the code:
/**
 * Driver that runs `MinhashUdaf` over the `name` column of `people.json`.
 *
 * Starts a local Spark session, shows the `name` column through plain SQL,
 * then applies the UDAF to the same column and prints the resulting schema
 * and rows.
 *
 * @param args command-line arguments (not used)
 */
def main(args: Array[String]): Unit = {
  val spark: SparkSession = SparkSession.builder
    .master("local[*]")
    .appName("test")
    .getOrCreate()
  import spark.implicits._

  try {
    val df = spark.read.json("people.json")
    df.createOrReplaceTempView("people")

    // Baseline: the raw column, for comparison with the UDAF output below.
    val sqlDF = spark.sql("Select name from people")
    sqlDF.show()

    val minhash = df.select(MinhashUdaf(col("name")).as("minhash"))
    minhash.printSchema()
    minhash.show(truncate = false)
  } finally {
    // Release the session even if reading or aggregating fails.
    spark.stop()
  }
}
Since the UDAF returns its input as it is, I expected to get the value of the column "name" back unchanged. However, when I run the above code, I get an empty string instead.
Upvotes: 2
Views: 681
Reputation: 910
You did not implement the merge function.
Using the code below, you can print the value of the column as you want.
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypes
/**
 * Minimal pass-through UDAF over a single string column.
 *
 * The buffer holds one string: the most recent value written by `update`.
 * `evaluate` returns that string, so the aggregate yields one of the input
 * values, or the initial `""` when no value was ever written.
 */
object MinhashUdaf extends UserDefinedAggregateFunction {

  /** Input: one StringType field named "value". */
  override def inputSchema: StructType = StructType(
    StructField("value", StringType) :: Nil
  )

  /** Buffer: one StringType field named "shingles". */
  override def bufferSchema: StructType = StructType(
    StructField("shingles", StringType) :: Nil
  )

  /** The aggregate's result type is a string. */
  override def dataType: DataType = StringType

  // NOTE(review): the result is whichever value was written last, which may
  // depend on row/partition order -- confirm `true` is the intended contract.
  override def deterministic: Boolean = true

  /** Starts every buffer with the empty string, used as the "no value yet" marker. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  /** Overwrites the buffered value with the first field of the input row. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, input.get(0))
  }

  /**
   * Folds a partial buffer into `buffer1`.
   *
   * `buffer2` is copied only when it actually carries a value. A partial
   * buffer that is null or still holds the initial `""` is skipped, so it
   * cannot erase a value already collected in `buffer1`.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (!buffer2.isNullAt(0) && buffer2.getString(0).nonEmpty) {
      buffer1.update(0, buffer2.getString(0))
    }
  }

  /** Returns the buffered string as the aggregate's result. */
  override def evaluate(buffer: Row): Any = buffer(0)
}
Upvotes: 3