Naveen Srikanth

Reputation: 789

PySpark: how to save and load a OneVsRest classifier with logistic regression

I am using PySpark 2.4.5, and I have a problem saving and loading a one-vs-rest classifier.

Below is the code:

 import time

 from pyspark.ml.classification import LogisticRegression, OneVsRest

 # df is the training DataFrame (default column names "label" and "features").
 start = time.time()
 lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True)
 # instantiate the One Vs Rest Classifier.
 ovr = OneVsRest(classifier=lr)
 # train the multiclass model; ovr.fit returns a OneVsRestModel.
 ovrModel = ovr.fit(df)
 end = time.time()

 ovrModel.save('s3://one_vs_Rest_model')

When loading the model, I use:

 lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True)

 # instantiate the One Vs Rest Classifier.
 ovr = OneVsRest(classifier=lr)
 ovr_mdl=ovr.load('s3://one_vs_Rest_model')

I get the following error:

 'requirement failed: Error loading metadata: Expected class name org.apache.spark.ml.classification.OneVsRest but found class name org.apache.spark.ml.classification.OneVsRestModel'
 Traceback (most recent call last):
   File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/util.py", line 362, in load
     return cls.read().load(path)
   File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/util.py", line 300, in load
     java_obj = self._jread.load(path)
   File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
     answer, self.gateway_client, self.target_id, self.name)
   File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py", line 79, in deco
     raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
 pyspark.sql.utils.IllegalArgumentException: 'requirement failed: Error loading metadata: Expected class name org.apache.spark.ml.classification.OneVsRest but found class name org.apache.spark.ml.classification.OneVsRestModel'

Upvotes: 0

Views: 813

Answers (1)

lbcommer

Reputation: 1035

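The problem is that you saved a OneVsRestModel object (the fitted model returned by ovr.fit), but you are trying to load it as a OneVsRest object (the unfitted estimator). Spark records the class name in the saved metadata, and the load method you call requires that name to match its own class, which is why the error says it expected OneVsRest but found OneVsRestModel.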
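If you are not sure whether a save directory holds an estimator or a fitted model, you can check the class name Spark recorded when it was saved. This is a minimal sketch, assuming a local filesystem path (not S3) and the layout Spark's default ML writer produces: a single JSON line in a part file under <path>/metadata. The helper name saved_class_name is hypothetical and not part of PySpark.

```python
import glob
import json
import os


def saved_class_name(model_path):
    """Return the class name recorded in a Spark ML save directory.

    Hypothetical helper. Assumes a local path and that the metadata was
    written as one JSON line in <model_path>/metadata/part-*.
    """
    # Hidden checksum files such as ".part-00000.crc" do not match "part-*".
    parts = sorted(glob.glob(os.path.join(model_path, "metadata", "part-*")))
    if not parts:
        raise FileNotFoundError(f"no metadata part file under {model_path!r}")
    with open(parts[0], encoding="utf-8") as fh:
        first_line = fh.readline()
    # The metadata JSON stores the fully qualified class under the "class" key.
    return json.loads(first_line)["class"]
```

For the model in this question it should report org.apache.spark.ml.classification.OneVsRestModel, which tells you to call OneVsRestModel.load rather than OneVsRest.load.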

To load the saved model, you can just do this:

from pyspark.ml.classification import OneVsRestModel
ovr_mdl = OneVsRestModel.load('s3://one_vs_Rest_model')

Upvotes: 1
