João Victor

Reputation: 435

How to create an Estimator that trains on new samples after already being fitted to an initial dataset?

I'm trying to create my own Estimator following the DeveloperApiExample.scala example I found in the Spark source code.

But in this example, every time I call the fit() method on the Estimator, it returns a new Model.

I want something like fitting again, to train on additional samples that have not been trained on yet.

I thought of creating a new method in the Model class to do this, but I'm not sure it makes sense. It may be worth knowing that my model doesn't need to process the whole dataset again to train on a new sample, and we don't want to change the model structure.
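To make the idea concrete, here is a minimal sketch (plain Scala, no Spark) of the kind of incremental update I mean. The names RunningMeanModel and update are hypothetical, for illustration only: the model absorbs a new sample in O(1) without revisiting the data it was originally fitted on.

```scala
// Hypothetical sketch: a model whose state can be updated with new
// samples without reprocessing the old ones.
final case class RunningMeanModel(count: Long, mean: Double) {
  // Incorporate one new sample in O(1), using only the stored state.
  def update(x: Double): RunningMeanModel = {
    val n = count + 1
    RunningMeanModel(n, mean + (x - mean) / n)
  }
}

object RunningMeanModel {
  // Initial fit: fold the incremental update over the first dataset.
  def fit(samples: Seq[Double]): RunningMeanModel =
    samples.foldLeft(RunningMeanModel(0, 0.0))(_ update _)
}
```

Updating the fitted model with a new sample gives the same result as refitting on the full data, which is the behaviour I'd like to expose through the Estimator/Model API.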

Upvotes: 2

Views: 245

Answers (3)

Alfilercio

Reputation: 1118

If you know how to improve the training of your model without retraining on the already-used data, you still can't do it in a single class: you would want a Model that is also an Estimator, but that is not directly possible, because both are abstract classes and cannot be mixed into the same class.

As you say, you can instead provide a method on the model that returns an Estimator to improve/extend the training:

class MyEstimator extends Estimator[MyModel] {
  ...
}

class MyModel extends Model[MyModel] {
  // Create an instance of my estimator that carries all the previous knowledge
  def retrain: MyEstimator = ???
}
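As a hedged, self-contained sketch of this retrain pattern (Spark's abstract classes replaced by plain Scala types so it runs on its own; MyEstimator, MyModel, and the weights field are illustrative, and the "training" step is a stand-in for a real learning algorithm):

```scala
// Estimator seeded with the previous model's state; an empty seed
// means training from scratch.
class MyEstimator(initialWeights: Vector[Double] = Vector.empty) {
  // Stand-in for a real training step: extends the carried-over
  // weights with state derived from the new data only.
  def fit(data: Seq[Double]): MyModel =
    new MyModel(initialWeights ++ data.map(_ * 0.1))
}

class MyModel(val weights: Vector[Double]) {
  // Return an estimator that carries this model's knowledge, so a
  // later fit() continues from it instead of starting over.
  def retrain: MyEstimator = new MyEstimator(weights)
}
```

Usage then reads naturally: `val m2 = m1.retrain.fit(newSamples)` trains only on the new samples while keeping everything the first fit learned.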

Upvotes: 1

Michael Xu

Reputation: 424

The base class for a Spark ML Estimator is defined here. As you can see, its fit method is a vanilla call that trains the model on the input data.

You should look at something like the LogisticRegression class, specifically the trainOnRows function, where the input is an RDD and, optionally, an initial coefficient matrix (the output of a previously trained model). This allows you to iteratively train a model on different datasets.
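The warm-start idea can be sketched without Spark at all (plain Scala; trainBatch and its parameters are hypothetical, not Spark's internal API): the optimizer accepts an initial coefficient, so each new batch continues from the previous model's weight instead of from zero.

```scala
// Gradient descent for a 1-D least-squares fit y ~ w * x.
// initW is the warm-start coefficient carried over from a prior model.
def trainBatch(data: Seq[(Double, Double)], initW: Double,
               lr: Double = 0.01, steps: Int = 500): Double = {
  var w = initW
  for (_ <- 0 until steps) {
    // Gradient of 0.5 * mean((w*x - y)^2) with respect to w.
    val grad = data.map { case (x, y) => (w * x - y) * x }.sum / data.size
    w -= lr * grad
  }
  w
}
```

A first batch is trained with `trainBatch(batch1, 0.0)`, and a later batch with `trainBatch(batch2, w1)`, mirroring the initial-coefficient-matrix argument described above.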

For what you need to achieve, remember that your algorithm of choice must support iterative updates: for example, GLMs, neural networks, and tree ensembles.

Upvotes: 1

mrk

Reputation: 10396

You can use PipelineModels to save, load, and continue fitting models:

MLlib standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Pipelines API, where the pipeline concept is mostly inspired by the scikit-learn project.

Find exemplary code here.

Upvotes: 0
