suiwenfeng

Reputation: 2003

Scala: how to parameterize a case class and pass it as a variable to [T <: Product: TypeTag]

// class definition of RsGoods schema
case class RsGoods(add_time: Int)

// my operation
originRDD.toDF[Schemas.RsGoods]()

// and the function definition
def toDF[T <: Product: TypeTag](): DataFrame = mongoSpark.toDF[T]()

Now I have defined too many schemas (RsGoods1, RsGoods2, RsGoods3), and more will be added in the future.

So the question is: how can I pass a case class as a variable to structure the code?

Attaching the sbt dependencies:

  "org.apache.spark" % "spark-core_2.11" % "2.3.0",
  "org.apache.spark" %% "spark-sql" % "2.3.0",
  "org.mongodb.spark" %% "mongo-spark-connector" % "2.3.1",

Attaching the key code snippet:

  var originRDD = MongoSpark.load(sc, readConfig)
  val df = table match {
    case "rs_goods_multi" => originRDD.toDF[Schemas.RsGoodsMulti]()
    case "rs_goods" => originRDD.toDF[Schemas.RsGoods]()
    case "ma_item_price" => originRDD.toDF[Schemas.MaItemPrice]()
    case "ma_siteuid" => originRDD.toDF[Schemas.MaSiteuid]()
    case "pi_attribute" => originRDD.toDF[Schemas.PiAttribute]()
    case "pi_attribute_name" => originRDD.toDF[Schemas.PiAttributeName]()
    case "pi_attribute_value" => originRDD.toDF[Schemas.PiAttributeValue]()
    case "pi_attribute_value_name" => originRDD.toDF[Schemas.PiAttributeValueName]()

Upvotes: 0

Views: 732

Answers (1)

sarveshseri

Reputation: 13985

From what I have understood of your requirement, I think the following should be a decent starting point.

def readDataset[A: Encoder](
  spark: SparkSession,
  mongoUrl: String,
  collectionName: String,
  clazz: Class[A]
): Dataset[A] = {
  val config = ReadConfig(
    Map("uri" -> s"$mongoUrl.$collectionName")
  )

  val df = MongoSpark.load(spark, config)

  // Take the declared field names of the case class and rename the DataFrame
  // columns (positionally) to match them, so that .as[A] can map them by name
  val fieldNames = clazz.getDeclaredFields.map(f => f.getName).dropRight(1).toList

  val dfWithMatchingFieldNames = df.toDF(fieldNames: _*)

  dfWithMatchingFieldNames.as[A]
}

You can use it like this,

case class RsGoods(add_time: Int)

val spark: SparkSession = ...

import spark.implicits._

val rsGoodsDS = readDataset[RsGoods](
  spark,
  "mongodb://example.com/database",
  "rs_goods",
  classOf[RsGoods]
)

Also, the following two lines,

val fieldNames = clazz.getDeclaredFields.map(f => f.getName).dropRight(1).toList

val dfWithMatchingFieldNames = df.toDF(fieldNames: _*)

are only required because normally Spark reads DataFrames with column names like value1, value2, .... So we want to change the column names to match what we have in our case class.

I am not sure what these "default" column names will be because MongoSpark is involved.

You should first check the column names in the df created as follows,

val config = ReadConfig(
  Map("uri" -> s"$mongoUrl.$collectionName")
)

val df = MongoSpark.load(spark, config)
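
For example, you can inspect the column names like this (a minimal sketch, assuming the df loaded above):

// Print the inferred schema and the column names of the loaded DataFrame
df.printSchema()
println(df.columns.mkString(", "))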

If MongoSpark fixes the problem of these "default" column names and picks the column names from your collection, then those 2 lines will not be required and your method will become just this,

def readDataset[A: Encoder](
  spark: SparkSession,
  mongoUrl: String,
  collectionName: String
): Dataset[A] = {
  val config = ReadConfig(
    Map("uri" -> s"$mongoUrl.$collectionName")
  )

  val df = MongoSpark.load(spark, config)

  df.as[A]
}

And,

val rsGoodsDS = readDataset[RsGoods](
  spark,
  "mongodb://example.com/database",
  "rs_goods"
)
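
The same helper then works for every other schema, since import spark.implicits._ provides the required Encoder for any case class. For example (a sketch, assuming the Schemas.MaItemPrice case class from your question):

val maItemPriceDS = readDataset[Schemas.MaItemPrice](
  spark,
  "mongodb://example.com/database",
  "ma_item_price"
)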

Upvotes: 1
