Reading a saved custom PySpark transformer

After messing with this for quite a while, I am finally able to save a pure Python custom transformer in Spark 2.3. But I get an error when loading the transformer back.

I checked the content of what was saved and found all the relevant variables in the file on HDFS. It would be great if someone could spot what I am missing in this simple transformer.

from pyspark.ml import Transformer
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import functions as F

class AggregateTransformer(Transformer, DefaultParamsWritable, DefaultParamsReadable):
    aggCols = Param(Params._dummy(), "aggCols", "columns to group by",
                    typeConverter=TypeConverters.toListString)
    valCols = Param(Params._dummy(), "valCols", "columns to aggregate",
                    typeConverter=TypeConverters.toListString)

    def __init__(self, aggCols, valCols):
        super(AggregateTransformer, self).__init__()
        self._setDefault(aggCols=[''])
        self._set(aggCols=aggCols)
        self._setDefault(valCols=[''])
        self._set(valCols=valCols)

    def getAggCols(self):
        return self.getOrDefault(self.aggCols)

    def setAggCols(self, aggCols):
        self._set(aggCols=aggCols)

    def getValCols(self):
        return self.getOrDefault(self.valCols)

    def setValCols(self, valCols):
        self._set(valCols=valCols)

    def _transform(self, dataset):
        # Build one aggregate expression per value column, then group by the key columns.
        aggFuncs = []
        for valCol in self.getValCols():
            aggFuncs.append(F.sum(valCol).alias("sum_" + valCol))
            aggFuncs.append(F.min(valCol).alias("min_" + valCol))
            aggFuncs.append(F.max(valCol).alias("max_" + valCol))
            aggFuncs.append(F.count(valCol).alias("cnt_" + valCol))
            aggFuncs.append(F.avg(valCol).alias("avg_" + valCol))
            aggFuncs.append(F.stddev(valCol).alias("stddev_" + valCol))

        return dataset.groupBy(self.getAggCols()).agg(*aggFuncs)

I get this error when I load an instance of this transformer after saving it.
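For reference, this is roughly the repro (the column names here are illustrative, not from my real data):

agg = AggregateTransformer(aggCols=["customer_id"], valCols=["amount"])
agg.transform(df).show()   # the transform itself works
agg.save("/tmp/test")      # saving works too
x = agg.load("/tmp/test")  # loading fails with the error below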

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-172-44e20f7e3842> in <module>()
----> 1 x = agg.load("/tmp/test")

/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(cls, path)
    309     def load(cls, path):
    310         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 311         return cls.read().load(path)
    312 
    313 

/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(self, path)
    482         metadata = DefaultParamsReader.loadMetadata(path, self.sc)
    483         py_type = DefaultParamsReader.__get_class(metadata['class'])
--> 484         instance = py_type()
    485         instance._resetUid(metadata['uid'])
    486         DefaultParamsReader.getAndSetParams(instance, metadata)

TypeError: __init__() missing 2 required positional arguments: 'aggCols' and 'valCols'

Upvotes: 7

Views: 1235

Answers (1)

Figured out the answer!

The problem was that the reader instantiates a new object of the class with no arguments (you can see py_type() being called in the traceback), but the __init__ for my AggregateTransformer didn't have default values for its arguments.

So changing the following line of code fixed the issue!

def __init__(self, aggCols=[], valCols=[]):
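With that change the save/load round trip works. For completeness, a minimal end-to-end sketch (the sample DataFrame and column names are made up for illustration):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy data: one key column to group by, one value column to aggregate
# (the names are illustrative).
df = spark.createDataFrame([("a", 1.0), ("a", 3.0), ("b", 2.0)],
                           ["customer_id", "amount"])

agg = AggregateTransformer(aggCols=["customer_id"], valCols=["amount"])
agg.save("/tmp/test")

# With defaults on __init__, DefaultParamsReader can instantiate the class
# with no arguments and then restore the params from the saved metadata.
x = AggregateTransformer.load("/tmp/test")
x.transform(df).show()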

I am going to leave this question and answer here, since it was incredibly difficult for me to find a working example of a pure Python transformer that can be saved and read back. Hopefully it helps someone looking for this.

Upvotes: 5
