Snek

Reputation: 59

Extending DefaultParamsReadable and DefaultParamsWritable not allowing reading of custom model

Good day,

I have been struggling for a few days to save a custom transformer that is part of a large pipeline of stages. I have a transformer that is completely defined by its params. I have an estimator whose fit method generates a matrix and then sets the transformer's parameters accordingly, so that I can use DefaultParamsReadable and DefaultParamsWritable to take advantage of the serialisation/deserialisation already present in util.ReadWrite.scala.

My summarised code (showing the important parts) is as follows:

...
import org.apache.spark.ml.util._
...

// trait to implement in Estimator and Transformer for params
trait NBParams extends Params {
  
  final val featuresCol = new Param[String](this, "featuresCol", "The input column")
  setDefault(featuresCol, "_tfIdfOut")

  final val labelCol = new Param[String](this, "labelCol", "The labels column")
  setDefault(labelCol, "P_Root_Code_Index")
  
  final val predictionsCol = new Param[String](this, "predictionsCol", "The output column")
  setDefault(predictionsCol, "NBOutput")
  
  final val ratioMatrix = new Param[DenseMatrix](this, "ratioMatrix", "The transformation matrix")
  
  def getFeaturesCol: String = $(featuresCol)
  def getLabelCol: String = $(labelCol)
  def getPredictionsCol: String = $(predictionsCol)
  def getRatioMatrix: DenseMatrix = $(ratioMatrix)
  
}


// Estimator
class CustomNaiveBayes(override val uid: String, val alpha: Double)
  extends Estimator[CustomNaiveBayesModel] with NBParams with DefaultParamsWritable {

  def copy(extra: ParamMap): CustomNaiveBayes = defaultCopy(extra)

  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  def setLabelCol(value: String): this.type = set(labelCol, value)

  def setPredictionCol(value: String): this.type = set(predictionsCol, value)

  def setRatioMatrix(value: DenseMatrix): this.type = set(ratioMatrix, value)

  override def transformSchema(schema: StructType): StructType = {...}

  override def fit(ds: Dataset[_]): CustomNaiveBayesModel = {
    ...
    // the ratio matrix is computed from the dataset in the elided code above
    val model = new CustomNaiveBayesModel(uid)
    model
      .setRatioMatrix(ratioMatrix)
      .setFeaturesCol($(featuresCol))
      .setLabelCol($(labelCol))
      .setPredictionCol($(predictionsCol))
  }
}

// companion object for Estimator
object CustomNaiveBayes extends DefaultParamsReadable[CustomNaiveBayes]{
  override def load(path: String): CustomNaiveBayes = super.load(path)
}

// Transformer
class CustomNaiveBayesModel(override val uid: String) 
  extends Model[CustomNaiveBayesModel] with NBParams with DefaultParamsWritable {  
    
  def this() = this(Identifiable.randomUID("customnaivebayes"))
   
  def copy(extra: ParamMap): CustomNaiveBayesModel = {defaultCopy(extra)}
    
  def setFeaturesCol(value: String): this.type = set(featuresCol, value) 
    
  def setLabelCol(value: String): this.type = set(labelCol, value) 
    
  def setPredictionCol(value: String): this.type = set(predictionsCol, value) 
    
  def setRatioMatrix(value: DenseMatrix): this.type = set(ratioMatrix, value) 

  override def transformSchema(schema: StructType): StructType = {...}

  def transform(dataset: Dataset[_]): DataFrame = {...}
}


// companion object for Transformer
object CustomNaiveBayesModel extends DefaultParamsReadable[CustomNaiveBayesModel] 

When I add this model as a stage in a pipeline and fit the pipeline, everything runs fine. Saving the pipeline also produces no errors. However, when I attempt to load the pipeline back in, I get the following error:

NoSuchMethodException: $line3b380bcad77e4e84ae25a6bfb1f3ec0d45.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$$$6fa979eb27fa6bf89c6b6d1b271932c$$$$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$CustomNaiveBayesModel.read()

To save the pipeline, which includes a number of other transformers related to NLP pre-processing, I run

fittedModelRootCode.write.save("path")

and then, to load it (which is where the failure occurs), I run

import org.apache.spark.ml.PipelineModel
val fittedModelRootCode = PipelineModel.load("path")

The model itself appears to work well, but I cannot afford to retrain it on a dataset every time I wish to use it. Does anyone have any idea why, even with the companion object, the read() method appears to be unavailable?
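For reference, the round trip I would expect the companion object to enable for the model on its own (outside any pipeline) is sketched below; the path is just a placeholder:

// hypothetical sanity check: save and load the standalone model to exercise
// the companion object's read path directly, bypassing PipelineModel
val nbModel = new CustomNaiveBayesModel()
  .setFeaturesCol("_tfIdfOut")
  .setLabelCol("P_Root_Code_Index")
  .setPredictionCol("NBOutput")

nbModel.write.overwrite().save("/tmp/nb-model")            // provided by DefaultParamsWritable
val reloaded = CustomNaiveBayesModel.load("/tmp/nb-model") // provided by DefaultParamsReadable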

Any help would be greatly appreciated.

Upvotes: 3

Views: 365

Answers (1)

Boris Azanov

Reputation: 4491

Since you extend the CustomNaiveBayesModel companion object with DefaultParamsReadable, I think you should use that companion object to load the model. Here is some code I wrote for saving and loading models, and it works properly:

import org.apache.spark.SparkConf
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.SparkSession
import path.to.CustomNaiveBayesModel


object SavingModelApp extends App {

  val spark: SparkSession = SparkSession.builder().config(
    new SparkConf()
      .setMaster("local[*]")
      .setAppName("Test app")
      .set("spark.driver.host", "localhost")
      .set("spark.ui.enabled", "false")
  ).getOrCreate()

  val training = spark.createDataFrame(Seq(
    (0L, "a b c d e spark", 1.0),
    (1L, "b d", 0.0),
    (2L, "spark f g h", 1.0),
    (3L, "hadoop mapreduce", 0.0)
  )).toDF("id", "text", "label")
  val fittedModelRootCode: PipelineModel =
    new Pipeline().setStages(Array(new CustomNaiveBayesModel())).fit(training)
  fittedModelRootCode.write.save("path/to/model")
  val mod = PipelineModel.load("path/to/model")
}

I think your mistake is using PipelineModel.load to load the concrete model.
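As a contrast, here is a minimal sketch (paths are illustrative) of the two load paths as I understand them:

// a single model saved through DefaultParamsWritable...
new CustomNaiveBayesModel().write.overwrite().save("path/to/nb-model")

// ...comes back through its companion object (DefaultParamsReadable)
val nbModel = CustomNaiveBayesModel.load("path/to/nb-model")

// whereas PipelineModel.load expects a directory written by a fitted PipelineModel
val pipeline = PipelineModel.load("path/to/model")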

My environment:

scalaVersion := "2.12.6"
scalacOptions := Seq(
  "-encoding", "UTF-8", "-target:jvm-1.8", "-deprecation",
  "-feature", "-unchecked", "-language:implicitConversions", "-language:postfixOps")

libraryDependencies += "org.apache.spark" %% "spark-core" % "3.1.1",
libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.1.1"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.1.1"

Upvotes: 0
