Javiar Sandra

Save custom transformers in PySpark

I am running the following Python code in Azure Databricks:

from pyspark.ml import Pipeline, Transformer

class customTransformations(Transformer):
    <code>

custom_transformer = customTransformations()
....
pipeline = Pipeline(stages=[custom_transformer, assembler, scaler, rf])
pipeline_model = pipeline.fit(sample_data)
pipeline_model.save(<your path>)

When I try to save the fitted pipeline, I get this error:

AttributeError: 'customTransformations' object has no attribute '_to_java'

Are there any workarounds?

Answers (1)

dportman

There seems to be no easy workaround other than implementing the _to_java method yourself, as suggested for StopWordsRemover in "Serialize a custom transformer using python to be used within a Pyspark ML pipeline":

import dill
from pyspark import SparkContext
from pyspark.ml.wrapper import JavaParams

# PysparkObjId and PysparkPipelineWrapper are helpers defined in the linked answer, not part of PySpark.

def _to_java(self):
    """
    Convert this instance to a dill dump, then to a list of strings holding the integer value of each byte.
    Use this list as a set of dummy stopwords and store it in a StopWordsRemover instance.
    :return: Java object equivalent to this instance.
    """
    dmp = dill.dumps(self)
    pylist = [str(b) for b in dmp]  # Python 3: iterating over bytes yields ints, so convert each one to a string
    pylist.append(PysparkObjId._getPyObjId())  # add our id so PysparkPipelineWrapper can identify us
    sc = SparkContext._active_spark_context
    java_class = sc._gateway.jvm.java.lang.String
    java_array = sc._gateway.new_array(java_class, len(pylist))  # Java String[] of the same length
    for i in range(len(pylist)):
        java_array[i] = pylist[i]
    _java_obj = JavaParams._new_java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
    _java_obj.setStopWords(java_array)  # the carrier StopWordsRemover persists these strings as its stopwords
    return _java_obj
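
To make the encoding trick above concrete, here is a minimal plain-Python sketch of it together with the inverse step that a matching _from_java would need on load. It is an illustration under stated assumptions, not code from the linked answer: it swaps dill for the standard-library pickle, uses a hypothetical marker string in place of PysparkObjId._getPyObjId(), and serializes a plain dict instead of the transformer object itself.

```python
import pickle

# Hypothetical marker standing in for PysparkObjId._getPyObjId() from the linked answer.
PY_OBJ_ID = "__py_obj_marker__"


def encode_to_string_list(obj):
    """Serialize obj, turn each byte into a decimal string, then append the marker."""
    dmp = pickle.dumps(obj)  # the answer uses dill.dumps(self); pickle is swapped in here
    pylist = [str(b) for b in dmp]  # iterating over bytes in Python 3 yields ints 0-255
    pylist.append(PY_OBJ_ID)
    return pylist


def decode_from_string_list(pylist):
    """Inverse of encode_to_string_list: check the marker, rebuild the bytes, unpickle."""
    if not pylist or pylist[-1] != PY_OBJ_ID:
        raise ValueError("list was not produced by encode_to_string_list")
    dmp = bytes(int(s) for s in pylist[:-1])
    return pickle.loads(dmp)


# Plain dict standing in for a transformer's Python-side state.
state = {"input_cols": ["age", "income"], "threshold": 0.5}
stop_words = encode_to_string_list(state)  # what would be stored as "stopwords"
print(decode_from_string_list(stop_words) == state)  # → True
```

Only plain strings cross into the JVM, which is presumably why a StopWordsRemover works as the carrier: its stopWords param is a string array that Spark already knows how to save, and the Python side only has to spot the marker and reverse the encoding when the pipeline is loaded.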
