Reputation: 879
After working on this for quite a while, I can finally save a pure Python custom transformer in Spark 2.3, but I get an error when loading it back.
I checked what was saved and found all the relevant variables in the file on HDFS. It would be great if someone could spot what I am missing in this simple transformer.
import pyspark.sql.functions as F
from pyspark.ml import Transformer
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable


class AggregateTransformer(Transformer, DefaultParamsWritable, DefaultParamsReadable):
    # Columns to group by, and columns to aggregate within each group.
    aggCols = Param(Params._dummy(), "aggCols", "", TypeConverters.toListString)
    valCols = Param(Params._dummy(), "valCols", "", TypeConverters.toListString)

    def __init__(self, aggCols, valCols):
        super(AggregateTransformer, self).__init__()
        self._setDefault(aggCols=[''])
        self._set(aggCols=aggCols)
        self._setDefault(valCols=[''])
        self._set(valCols=valCols)

    def getAggCols(self):
        return self.getOrDefault(self.aggCols)

    def setAggCols(self, aggCols):
        self._set(aggCols=aggCols)

    def getValCols(self):
        return self.getOrDefault(self.valCols)

    def setValCols(self, valCols):
        self._set(valCols=valCols)

    def _transform(self, dataset):
        # Build six summary statistics for every value column.
        aggFuncs = []
        for valCol in self.getValCols():
            aggFuncs.append(F.sum(valCol).alias("sum_" + valCol))
            aggFuncs.append(F.min(valCol).alias("min_" + valCol))
            aggFuncs.append(F.max(valCol).alias("max_" + valCol))
            aggFuncs.append(F.count(valCol).alias("cnt_" + valCol))
            aggFuncs.append(F.avg(valCol).alias("avg_" + valCol))
            aggFuncs.append(F.stddev(valCol).alias("stddev_" + valCol))
        dataset = dataset.groupBy(self.getAggCols()).agg(*aggFuncs)
        return dataset
I get this error when I load an instance of this transformer after saving it.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-172-44e20f7e3842> in <module>()
----> 1 x = agg.load("/tmp/test")
/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(cls, path)
309 def load(cls, path):
310 """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 311 return cls.read().load(path)
312
313
/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(self, path)
482 metadata = DefaultParamsReader.loadMetadata(path, self.sc)
483 py_type = DefaultParamsReader.__get_class(metadata['class'])
--> 484 instance = py_type()
485 instance._resetUid(metadata['uid'])
486 DefaultParamsReader.getAndSetParams(instance, metadata)
TypeError: __init__() missing 2 required positional arguments: 'aggCols' and 'valCols'
Upvotes: 7
Views: 1235
Reputation: 879
Figured out the answer!
The problem was that the reader creates a new instance of the transformer class with no arguments (the py_type() call in the traceback), but the __init__ function of my AggregateTransformer didn't have default values for its arguments.
Changing the following line of code fixed the issue:
def __init__(self, aggCols=[], valCols=[]):
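The failure and the fix can be reproduced without Spark. The sketch below is only an illustration of the pattern visible in the traceback (construct the class with no arguments, then restore the saved values); it is not pyspark's actual reader code. The names load_instance, StrictTransformer and LenientTransformer are hypothetical, and the defaults use None instead of [] purely as a common Python habit; the [] defaults above behave the same way for this purpose.

```python
def load_instance(py_type, saved_params):
    """Hypothetical stand-in for the reader: build first, fill in afterwards."""
    instance = py_type()  # no-argument call, like `instance = py_type()` in the traceback
    for name, value in saved_params.items():
        setattr(instance, name, value)
    return instance


class StrictTransformer:
    # Mirrors the original class: both arguments are required.
    def __init__(self, aggCols, valCols):
        self.aggCols = aggCols
        self.valCols = valCols


class LenientTransformer:
    # Mirrors the fix: every argument has a default, so py_type() succeeds.
    def __init__(self, aggCols=None, valCols=None):
        self.aggCols = list(aggCols or [])
        self.valCols = list(valCols or [])


# Pretend these values were read back from the saved metadata.
saved = {"aggCols": ["store"], "valCols": ["sales"]}

try:
    load_instance(StrictTransformer, saved)
    strict_error = None
except TypeError as exc:
    strict_error = str(exc)

restored = load_instance(LenientTransformer, saved)

print(strict_error)                         # message names the missing arguments
print(restored.aggCols, restored.valCols)   # values restored after construction
```

The point of the sketch is that the constructor arguments are never supplied at load time; the saved values are applied only after the object exists, so the constructor must be callable with no arguments.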
I am leaving this question and answer here because it was incredibly difficult to find a working example anywhere of a pure Python transformer that can be saved and read back. It may help someone looking for the same thing.
Upvotes: 5