Georg Heiler

Reputation: 17674

Spark custom estimator including persistence

I want to develop a custom estimator for Spark which also handles persistence through the great Pipeline API. But as How to Roll a Custom Estimator in PySpark mllib puts it, there is not a lot of documentation out there (yet).

I have some data cleansing code written in Spark and would like to wrap it in a custom estimator. It includes NA substitutions, column deletions, filtering and basic feature generation (e.g. birthdate to age).

What is still pretty unclear to me:

  • how should I transfer the fitted values from the estimator to the transformer?
  • how should I handle persistence?

It would be great if you could help me with a custom estimator - especially with the persistence part.

Upvotes: 6

Views: 2947

Answers (2)

Marsellus Wallace

Reputation: 18601

The following uses the Scala API but you can easily refactor it to Python if you really want to...

First things first:

  • Estimator: implements .fit() that returns a Transformer
  • Transformer: implements .transform() and manipulates the DataFrame
  • Serialization/Deserialization: Do your best to use the built-in Params and leverage the simple DefaultParamsWritable trait + a companion object extending DefaultParamsReadable[T]. In other words: stay away from MLReader / MLWriter and keep your code simple.
  • Parameters passing: Use a common trait extending Params and share it between your Estimator and Model (a.k.a. Transformer)

Skeleton code:

// Imports needed by the skeleton below
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap, Params, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{avg, col}
import org.apache.spark.sql.types.StructType

// Common Parameters
trait MyCommonParams extends Params {
  final val inputCols: StringArrayParam = // usage: new MyMeanValueStuff().setInputCols(...)
    new StringArrayParam(this, "inputCols", "doc...")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
  def getInputCols: Array[String] = $(inputCols)

  final val meanValues: DoubleArrayParam =
    new DoubleArrayParam(this, "meanValues", "doc...")
  def setMeanValues(value: Array[Double]): this.type = set(meanValues, value)
  def getMeanValues: Array[Double] = $(meanValues)
}

// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel]
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  def this() = this(Identifiable.randomUID("MyMeanValueStuff"))

  override def copy(extra: ParamMap): Estimator[MyMeanValueStuffModel] = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
    // your logic here. I can't do all the work for you! ;)
    // e.g. compute the mean of each input column...
    val computedMeans = $(inputCols).map(c => dataset.select(avg(col(c))).head().getDouble(0))
    // ...store it in the shared Param and copy all Params onto the new Model
    this.setMeanValues(computedMeans)
    copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this))
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]

// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel]
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  def this() = this(Identifiable.randomUID("MyMeanValueStuffModel"))

  override def copy(extra: ParamMap): MyMeanValueStuffModel = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def transform(dataset: Dataset[_]): DataFrame = {
    // your logic here: zip inputCols and meanValues, toMap, replace nulls with NA functions
    // you have access to both inputCols and meanValues here!
    val fillValues: Map[String, Any] = $(inputCols).zip($(meanValues)).toMap
    dataset.na.fill(fillValues)
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]

With the code above you can Serialize/Deserialize a Pipeline containing a MyMeanValueStuff stage.
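For instance, a minimal usage sketch (the sample data, column names and save path below are made up for illustration, not part of the code above):

import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
import spark.implicits._

// A tiny DataFrame with some nulls to fill (hypothetical data).
val df = Seq((Some(30.0), Some(50000.0)), (None, Some(42000.0)), (Some(25.0), None))
  .toDF("age", "income")

val pipeline = new Pipeline().setStages(Array[PipelineStage](
  new MyMeanValueStuff().setInputCols(Array("age", "income"))
))
val model = pipeline.fit(df)

// Round-trip the fitted pipeline through disk and use the reloaded copy.
model.write.overwrite().save("/tmp/my-mean-value-pipeline")
val reloaded = PipelineModel.load("/tmp/my-mean-value-pipeline")
reloaded.transform(df).show()

When loading, the PipelineModel reconstructs each stage through its companion object extending DefaultParamsReadable, which is why the companion objects matter.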

Want to look at some real simple implementation of an Estimator? MinMaxScaler! (My example is actually simpler though...)

Upvotes: 2

zero323

Reputation: 330203

First of all I believe you're mixing a bit two different things:

  • Estimators - which represent stages that can be fitted. An Estimator's fit method takes a Dataset and returns a Transformer (a model).
  • Transformers - which represent stages that can transform data.

When you fit a Pipeline it fits all the Estimators and returns a PipelineModel. The PipelineModel can transform data by sequentially calling transform on all the Transformers in the model.
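For example, with a built-in stage (a minimal sketch; the data and column names are made up):

import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
import spark.implicits._

val df = Seq("a", "b", "a", "c").toDF("category")

// Estimator: only holds configuration until fit() is called...
val indexer: StringIndexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")

// ...fit() learns the label-to-index mapping and returns a Transformer (the model).
val model: StringIndexerModel = indexer.fit(df)
model.transform(df).show()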

how should I transfer the fitted values

There is no single answer to this question. In general you have two options:

  • Pass parameters of the fitted model as the arguments of the Transformer.
  • Make parameters of the fitted model Params of the Transformer.

The first approach is typically used by the built-in Transformers, but the second one should work in some simple cases; a minimal sketch of the first option follows.
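A rough sketch of the first option (the class and field names below are hypothetical, not Spark API): the estimator computes the fitted values and passes them straight to the model's constructor.

import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Hypothetical model that receives the fitted values as constructor arguments
// instead of carrying them as Params.
class MeanSubstitutorModel(
    override val uid: String,
    val fillValues: Map[String, Double]) extends Model[MeanSubstitutorModel] {

  def this(fillValues: Map[String, Double]) =
    this(Identifiable.randomUID("MeanSubstitutorModel"), fillValues)

  override def copy(extra: ParamMap): MeanSubstitutorModel =
    copyValues(new MeanSubstitutorModel(uid, fillValues), extra)

  override def transformSchema(schema: StructType): StructType = schema // no schema changes

  // Replace nulls in the learned columns with the learned means.
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.na.fill(fillValues)
}

The matching Estimator's fit would compute fillValues from the Dataset and return new MeanSubstitutorModel(fillValues).setParent(this).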

how to handle persistence

  • If a Transformer is defined only by its Params you can extend DefaultParamsWritable and provide a companion object extending DefaultParamsReadable.
  • If you use more complex arguments you should extend MLWritable and implement an MLWriter that makes sense for your data. There are multiple examples in the Spark source which show how to implement metadata and data reading / writing; a rough sketch of a custom writer follows this list.
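For the second case, a rough sketch of such a writer (MeanSubstitutorModel is the hypothetical class from the previous sketch; real implementations also persist uid and Param metadata):

import org.apache.spark.ml.util.{MLWritable, MLWriter}
import org.apache.spark.sql.SparkSession

// Hypothetical writer that persists the model's fitted Map as a small
// Parquet dataset under the model path.
class MeanSubstitutorModelWriter(instance: MeanSubstitutorModel) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._
    instance.fillValues.toSeq
      .toDF("column", "fillValue")
      .write.parquet(s"$path/data")
  }
}

// The model then mixes in MLWritable and exposes the writer, e.g.:
//   class MeanSubstitutorModel(...) extends Model[MeanSubstitutorModel] with MLWritable {
//     override def write: MLWriter = new MeanSubstitutorModelWriter(this)
//     ...
//   }

Reading works symmetrically through MLReadable / MLReader on the companion object, which loads the Parquet back and reconstructs the model.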

If you're looking for an easy-to-comprehend example, take a look at the CountVectorizer(Model), which shares common Params between the Estimator and the Model and implements custom reading / writing for the fitted vocabulary.

Upvotes: 4
