Reputation: 17674
I want to develop a custom estimator for Spark that also handles persistence via the great pipeline API. But as How to Roll a Custom Estimator in PySpark mllib put it, there is not a lot of documentation out there (yet).
I have some data cleansing code written in Spark and would like to wrap it in a custom estimator. It includes some NA substitutions, column deletions, filtering and basic feature generation (e.g. birthdate to age). transformSchema will use the case class of the dataset:
ScalaReflection.schemaFor[MyClass].dataType.asInstanceOf[StructType]
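For illustration, a minimal sketch of that idea (the case class MyClass and its fields are hypothetical):

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType

case class MyClass(name: String, age: Int) // hypothetical schema of the cleaned data

// the output schema of the stage is derived from the case class:
def cleanedSchema: StructType =
  ScalaReflection.schemaFor[MyClass].dataType.asInstanceOf[StructType]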
What is still pretty unclear to me:
1. transform in the custom pipeline model will be used to transform the "fitted" Estimator on new data. Is this correct? If yes, how should I transfer the fitted values, e.g. the mean age from above, into the model?
2. How to handle persistence? I found some generic loadImpl method within private Spark components but am unsure how to transfer my own parameters, e.g. the mean age, into the MLReader / MLWriter which are used for serialization.
It would be great if you could help me with a custom estimator - especially with the persistence part.
Upvotes: 6
Views: 2947
Reputation: 18601
The following uses the Scala API but you can easily refactor it to Python if you really want to...
First things first:
- An Estimator has a .fit() that returns a Transformer.
- A Transformer has a .transform() and manipulates the DataFrame.
- Use the DefaultParamsWritable trait + a companion object extending DefaultParamsReadable[T], a.k.a. stay away from MLReader / MLWriter and keep your code simple.
- Define your parameters in a trait extending Params and share it between your Estimator and Model (a.k.a. Transformer).

Skeleton code:
// Imports assumed by the skeleton below
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap, Params, StringArrayParam}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Common Parameters
trait MyCommonParams extends Params {
final val inputCols: StringArrayParam = // usage: new MyMeanValueStuff().setInputCols(...)
new StringArrayParam(this, "inputCols", "doc...")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
def getInputCols: Array[String] = $(inputCols)
final val meanValues: DoubleArrayParam =
new DoubleArrayParam(this, "meanValues", "doc...")
def setMeanValues(value: Array[Double]): this.type = set(meanValues, value) // setter used by fit() below
def getMeanValues: Array[Double] = $(meanValues)
}
// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel]
with DefaultParamsWritable // Enables Serialization of MyCommonParams
with MyCommonParams {
def this() = this(Identifiable.randomUID("myMeanValueStuff")) // allows new MyMeanValueStuff()
override def copy(extra: ParamMap): Estimator[MyMeanValueStuffModel] = defaultCopy(extra) // default
override def transformSchema(schema: StructType): StructType = schema // no changes
override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
// your logic here: compute the column means from the dataset. I can't do all the work for you! ;)
val computedMeans: Array[Double] = ??? // one mean per column in getInputCols
this.setMeanValues(computedMeans)
copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this))
}
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]
// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel]
with DefaultParamsWritable // Enables Serialization of MyCommonParams
with MyCommonParams {
override def copy(extra: ParamMap): MyMeanValueStuffModel = defaultCopy(extra) // default
override def transformSchema(schema: StructType): StructType = schema // no changes
override def transform(dataset: Dataset[_]): DataFrame = {
// your logic here: zip getInputCols and getMeanValues, toMap, replace nulls via the DataFrame na functions
// you have access to both inputCols and meanValues here!
??? // placeholder: return the transformed DataFrame
}
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]
With the code above you can Serialize/Deserialize a Pipeline containing a MyMeanValueStuff stage.
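For completeness, a hedged usage sketch (the DataFrames trainingDf/newDf, the path and the column names are made up) of how such a Pipeline can be saved and reloaded:

import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}

val pipeline = new Pipeline().setStages(Array[PipelineStage](
  new MyMeanValueStuff().setInputCols(Array("age", "income"))
))

val fitted = pipeline.fit(trainingDf)              // runs MyMeanValueStuff.fit, producing the model stage
fitted.write.overwrite().save("/tmp/mean-value-pipeline")

val restored = PipelineModel.load("/tmp/mean-value-pipeline")
restored.transform(newDf)                          // meanValues come back through DefaultParamsReadable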
Want to look at a really simple implementation of an Estimator? MinMaxScaler! (My example is actually simpler, though...)
Upvotes: 2
Reputation: 330203
First of all, I believe you're mixing up two different things a bit:

- Estimators - which represent stages that can be fit-ted. An Estimator's fit method takes a Dataset and returns a Transformer (the model).
- Transformers - which represent stages that can transform data.

When you fit a Pipeline, it fits all the Estimators and returns a PipelineModel. The PipelineModel can transform data by sequentially calling transform on all the Transformers in the model.
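To make the distinction concrete, a small sketch using a built-in stage (the DataFrame trainDf and its vector column "features" are assumed):

import org.apache.spark.ml.feature.{StandardScaler, StandardScalerModel}

val scaler = new StandardScaler().setInputCol("features").setOutputCol("scaled") // Estimator
val scalerModel: StandardScalerModel = scaler.fit(trainDf)                       // fit returns a Transformer (the model)
val scaledDf = scalerModel.transform(trainDf)                                    // the Transformer manipulates the data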
how should I transfer the fitted values
There is no single answer to this question. In general you have two options:
- Pass the parameters of the fitted model as arguments of the Transformer.
- Make the parameters of the fitted model Params of the Transformer.

The first approach is typically used by the built-in Transformers, but the second one should work in some simple cases (both options are sketched below).
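A minimal sketch of the two options (the class names MeanAgeFillerA/B and the column "age" are invented; this is not Spark's built-in API, just the shape of each approach):

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{DoubleParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{coalesce, lit}
import org.apache.spark.sql.types.StructType

// Option 1: the fitted value travels as a constructor argument of the model
class MeanAgeFillerA(override val uid: String, val meanAge: Double) extends Transformer {
  def this(meanAge: Double) = this(Identifiable.randomUID("meanAgeFillerA"), meanAge)
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn("age", coalesce(dataset("age"), lit(meanAge)))
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): MeanAgeFillerA = copyValues(new MeanAgeFillerA(uid, meanAge), extra)
}

// Option 2: the fitted value is stored in a Param of the model (what the other answer's skeleton does)
class MeanAgeFillerB(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("meanAgeFillerB"))
  final val meanAge = new DoubleParam(this, "meanAge", "mean age computed during fit")
  def setMeanAge(value: Double): this.type = set(meanAge, value)
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn("age", coalesce(dataset("age"), lit($(meanAge))))
  override def transformSchema(schema: StructType): StructType = schema
  override def copy(extra: ParamMap): MeanAgeFillerB = defaultCopy(extra)
}

Option 2 pairs naturally with DefaultParamsWritable, since everything the model needs lives in its Params.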
how to handle persistence
- If the Transformer is defined only by its Params you can extend DefaultParamsReadable.
- If you need more complex arguments you should extend MLWritable and implement an MLWriter that makes sense for your data (a rough sketch follows this list). There are multiple examples in the Spark source which show how to implement data and metadata reading / writing.
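To make the second point concrete, here is a rough, hedged sketch of the shape such a writer/reader pair can take (all names are invented, the model is stripped down to just its fitted state, and the Params/metadata handling that Spark's built-in stages also persist is omitted):

import org.apache.hadoop.fs.Path
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}

// Stripped-down model holding only the fitted state (not a full Transformer).
class MyAgeModel(val uid: String, val meanValues: Array[Double]) extends MLWritable {
  override def write: MLWriter = new MyAgeModelWriter(this)
}

object MyAgeModel extends MLReadable[MyAgeModel] {
  override def read: MLReader[MyAgeModel] = new MyAgeModelReader
}

// Writer: persist the fitted values as a tiny DataFrame under <path>/data.
class MyAgeModelWriter(instance: MyAgeModel) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    val dataPath = new Path(path, "data").toString
    sparkSession.createDataFrame(Seq(Tuple1(instance.meanValues)))
      .toDF("meanValues")
      .repartition(1)
      .write.parquet(dataPath)
  }
}

// Reader: rebuild the model from what the writer stored (the uid is hard-coded for brevity;
// real code would persist and restore it as part of the metadata).
class MyAgeModelReader extends MLReader[MyAgeModel] {
  override def load(path: String): MyAgeModel = {
    val dataPath = new Path(path, "data").toString
    val means = sparkSession.read.parquet(dataPath)
      .select("meanValues").head().getAs[Seq[Double]](0).toArray
    new MyAgeModel("myAgeModel", means)
  }
}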
If you're looking for an easy-to-comprehend example, take a look at CountVectorizer(Model), where:

- Estimator and Transformer share common Params.
- Params metadata is written and read using DefaultParamsWriter / DefaultParamsReader.

Upvotes: 4