alex

Reputation: 2193

Scala Spark function with generic Dataset[T] argument and also returns Dataset[T]?

I understand that for Spark to turn a DataFrame into a Dataset[T] of some class, it requires an encoder. I can usually do my processing in a main method where the encoder is in scope and call .as[MyClass], like so:

val df = spark.read.parquet("something")
val myDS = df.as[MyClass]

This works as long as an encoder is defined for MyClass.

I want to create a function like this

def hello[T](inputDataSet: Dataset[T])(implicit spark: SparkSession): Dataset[T] = {

    val replacedDataFrame = inputDataSet
      // do some transformation as Dataframe
      .as[T]

    replacedDataFrame

}

where I also return a Dataset[T]. However, when I try to convert the DataFrame with .as[T], the compiler complains "No implicits found". I assumed that since it understands the type when I pass in a Dataset[T], it could also work out the reverse, but apparently not. Is there any way around this?

Example use case:

// function to replace a column with values from another DataSet
def swapColumnValue[T,K](inputDS: Dataset[T], joinableDS: Dataset[K])(implicit spark: SparkSession): Dataset[T] = {

    val replacedDataFrame = inputDS
      .join(broadcast(joinableDS), "col1") // "col1" exists in both "joinableDS" and "inputDS"
      .withColumnRenamed("col1", "to-drop")
      .withColumnRenamed("col2", "col1") // "col2" exists only in "joinableDS"
      .drop("to-drop") 
      .as[T]

    replacedDataFrame

}

Note that this isn't my only use case. The problem is that I pass in a Dataset[T], and after doing some operations on it I would like to return a Dataset[T] as well. Once I do the join, the Dataset becomes a DataFrame and loses track of which class was defined as T.

Upvotes: 1

Views: 1406

Answers (1)

Ged

Reputation: 18013

Try this. It adds an implicit Encoder[T] parameter to the function, which resolves the error message you get:

import org.apache.spark.sql.functions._
import org.apache.spark.sql._
import spark.implicits._
import org.apache.spark.sql.Encoders

case class T(name: String, age: Long)
case class K(name: String, age2: Long)

val dt = Seq(T("Andy", 32), T("John", 33), T("Bob", 33)).toDS()
dt.show()

val dk = Seq(K("Andy", 32), K("John", 133), K("Bob", 245)).toDS()
dk.show()

implicit val sqlContext: SparkSession = spark

def swapColumnValue[T,K](inputDS: Dataset[T], joinableDS: Dataset[K])(implicit spark: SparkSession, encoder: Encoder[T]): Dataset[T] = {
//def swapColumnValue[T,K](inputDS: Dataset[T], joinableDS: Dataset[K]) : DataFrame = {
    val replacedDataFrame = inputDS
      .join(broadcast(joinableDS), "name")  
      .withColumnRenamed("age", "to-drop")
      .withColumnRenamed("age2", "age")  
      .drop("to-drop") 
      .as[T]
  
    replacedDataFrame
}

val ds = swapColumnValue(dt,dk) 
ds.show(false)

returns:

+----+---+
|name|age|
+----+---+
|Andy| 32|
|John| 33|
| Bob| 33|
+----+---+

+----+----+
|name|age2|
+----+----+
|Andy|  32|
|John| 133|
| Bob| 245|
+----+----+

+----+---+
|name|age|
+----+---+
|Andy|32 |
|John|133|
|Bob |245|
+----+---+

ds is a Dataset[T], where T here is the case class defined above.
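For what it's worth, the same idea can probably be written with a context bound, which is just shorthand for the extra implicit Encoder[T] parameter. The following is a minimal sketch, not a definitive implementation. The object name SwapColumnSketch and the case classes Person and AgeUpdate are hypothetical names I made up for illustration (chosen so they don't clash with the type parameters T and K), and it assumes a local SparkSession rather than the one provided by spark-shell. It also drops the implicit SparkSession parameter, since the function body doesn't use it.

```scala
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import org.apache.spark.sql.functions.broadcast

object SwapColumnSketch {
  // Hypothetical case classes, defined outside any method so Spark can derive encoders for them.
  case class Person(name: String, age: Long)
  case class AgeUpdate(name: String, age2: Long)

  // `T: Encoder` is a context bound: the compiler adds an implicit Encoder[T] parameter,
  // which is what `.as[T]` looks for.
  def swapColumnValue[T: Encoder, K](inputDS: Dataset[T], joinableDS: Dataset[K]): Dataset[T] =
    inputDS
      .join(broadcast(joinableDS), "name") // join returns a DataFrame, so the static type T is lost here
      .drop("age")                         // drop the old column
      .withColumnRenamed("age2", "age")    // the replacement column takes its name
      .as[T]                               // uses the implicit Encoder[T] from the context bound

  def main(args: Array[String]): Unit = {
    // Assumption: a local session for illustration; in spark-shell `spark` already exists.
    val spark = SparkSession.builder().master("local[*]").appName("swap-column-sketch").getOrCreate()
    import spark.implicits._ // supplies Encoder[Person] and Encoder[AgeUpdate] at the call site

    val people  = Seq(Person("Andy", 32), Person("John", 33), Person("Bob", 33)).toDS()
    val updates = Seq(AgeUpdate("Andy", 32), AgeUpdate("John", 133), AgeUpdate("Bob", 245)).toDS()

    val swapped: Dataset[Person] = swapColumnValue(people, updates)
    swapped.show(false)

    spark.stop()
  }
}
```

The point is that inside a generic function the compiler does not know what T is, so it cannot pick an encoder itself. The caller, who knows the concrete class, has to hand one in, and spark.implicits._ does that automatically for case classes.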

Upvotes: 3
