Reputation: 322
I am using some PySpark transformers, such as StringIndexer, StandardScaler and others. I first fit them on the training set, and later I want to apply the same fitted transformation objects (the same StringIndexerModel and StandardScalerModel parameters) to the test set. Therefore, I am looking for a way to save these transformers to a file. However, I cannot find any related method; I only find one for ml models such as LogisticRegression. Do you know any way to do that? Thanks.
Upvotes: 2
Views: 2108
Reputation: 322
I found an easy solution.
Save the indexer model to a file (on HDFS).
writer = indexerModel._call_java("write")
writer.save("indexerModel")
Load the indexer model from a file (saved on HDFS).
from pyspark.ml.feature import StringIndexerModel
indexer = StringIndexerModel._new_java_obj("org.apache.spark.ml.feature.StringIndexerModel.load", "indexerModel")
indexerModel = StringIndexerModel(indexer)
Upvotes: 3
Reputation: 8571
The outputs of the fitted StringIndexer and StandardScaler models are DataFrames (or RDDs, if you use the older pyspark.mllib StandardScaler), so you can either save the models directly to a file or, more likely what you want, persist the results for later computation.
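To save to a Parquet file, call string_indexed_df.write.parquet("indexer.parquet") on a DataFrame; for an RDD, convert it first with sqlContext.createDataFrame(string_indexed_rdd).write.parquet("indexer.parquet") (you might need to attach a schema as well). You would then need to write code that loads this result back from the file (e.g. with sqlContext.read.parquet("indexer.parquet")) when you want it.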
To persist, call string_indexed_rdd.persist(). This keeps the intermediate results in memory for later reuse. If you are memory-limited, you can pass a storage level (e.g. StorageLevel.MEMORY_AND_DISK) so that the data spills to disk as well.
If you just want to persist the model itself, you're stuck on an existing bug/missing capability in the API (PR). Unless the underlying issue is resolved with new methods, you need to call some underlying methods manually to get and set the model parameters. Looking through the model code, you can see that the models inherit from a chain of classes, one of which is Params. This class has extractParamMap, which pulls out the parameters used in the model. You can then save this in any manner you wish for persisting Python dicts. Then create an empty model object and call copy(saved_params) on it to pass the persisted parameters into the object. Note that this only restores Params such as inputCol and outputCol; fitted state such as the StringIndexerModel labels or the StandardScalerModel mean and std is not a Param and is not captured this way.
Something along these lines should work:
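import shelve

def save_params(model, filename):
    # shelve keys must be strings, so store each Param's value under its name
    with shelve.open(filename) as d:
        d.update({param.name: value for param, value in model.extractParamMap().items()})

def load_params(ModelClass, filename):
    # map saved names back to the new instance's own Param objects before copying
    model = ModelClass()
    with shelve.open(filename) as d:
        return model.copy({model.getParam(name): value for name, value in d.items()})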
Upvotes: 0