Reputation: 343
I have a DataFrame with a column named 'counts', and I would like to apply a custom function, do_something, to each element of that column, i.e. to each array. I do not want to modify the DataFrame; I just want to run a separate operation on the counts column. All arrays in the column have the same size.
+---+---------------+
| id|         counts|
+---+---------------+
|  1| [8.0, 2.0, 3.0|
|  2| [1.0, 6.0, 3.0|
+---+---------------+
When I try this:
df.select('counts').rdd.foreach(lambda x: do_something(x))
it fails with the traceback below. Passing do_something directly, without the lambda, gives the same error.
Py4JJavaError Traceback (most recent call last)
in ()
----> 1 df.select('counts').rdd.foreach(lambda x: do_something(x))

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in foreach(self, f)
    745 f(x)
    746 return iter([])
--> 747 self.mapPartitions(processPartition).count() # Force evaluation
    748
    749 def foreachPartition(self, f):

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in count(self)
   1002 3
   1003 """
-> 1004 return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
   1005
   1006 def stats(self):

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in sum(self)
    993 6.0
    994 """
--> 995 return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
    996
    997 def count(self):

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in fold(self, zeroValue, op)
    867 # zeroValue provided to each partition is unique from the one provided
    868 # to the final reduce call
--> 869 vals = self.mapPartitions(func).collect()
    870 return reduce(op, vals, zeroValue)
    871

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in collect(self)
    769 """
    770 with SCCallSiteSync(self.context) as css:
--> 771 port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
    772 return list(_load_from_socket(port, self._jrdd_deserializer))
    773

/usr/hdp/2.5.3.0-37/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
    811 answer = self.gateway_client.send_command(command)
    812 return_value = get_return_value(
--> 813 answer, self.gateway_client, self.target_id, self.name)
    814
    815 for temp_arg in temp_args:

/usr/hdp/2.5.3.0-37/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     43 def deco(*a, **kw):
     44 try:
---> 45 return f(*a, **kw)
     46 except py4j.protocol.Py4JJavaError as e:
     47 s = e.java_exception.toString()

/usr/hdp/2.5.3.0-37/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    306 raise Py4JJavaError(
    307 "An error occurred while calling {0}{1}{2}.\n".
--> 308 format(target_id, ".", name), value)
    309 else:
    310 raise Py4JError(
This happens even though all input arrays have the same size. Here is do_something:
big_list = []

def do_something(i_array):
    outputs = custom_library(i_array)  # takes an array as input and returns 3 new lists
    big_list.extend(outputs)
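A minimal pure-Python sketch of why big_list cannot be filled this way. It assumes the standard PySpark behaviour that a task's function, together with any global it references, is serialized and deserialized in a separate Python worker process. The helper `ship_to_worker` and the function `do_something_on_worker` are hypothetical and only mimic that round-trip with `pickle`; no Spark is involved.

```python
import pickle

big_list = []


def ship_to_worker(obj):
    # Hypothetical helper: a pickle round-trip stands in for Spark serializing
    # what a task needs and deserializing it in a separate worker process.
    return pickle.loads(pickle.dumps(obj))


def do_something_on_worker(outputs):
    # The worker only ever sees a deserialized copy of big_list ...
    worker_big_list = ship_to_worker(big_list)
    # ... so extending it changes the copy, not the driver's list.
    worker_big_list.extend(outputs)
    return worker_big_list


worker_result = do_something_on_worker([[1.0], [2.0], [3.0]])
print(len(worker_result))  # 3
print(len(big_list))       # 0: the driver's list is untouched
```

Returning the outputs from the function, rather than mutating a shared list, avoids relying on state that never travels back to the driver.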
Upvotes: 1
Views: 2805
Reputation: 10086
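Your UDF modifies a Python object, big_list, that is defined outside of it. Spark runs the function on the executors, each of which works on its own serialized copy of big_list, so the list on the driver is never updated.
You can try returning the outputs instead and wrapping the function in a UDF: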
def do_something(i_array):
    outputs = custom_library(i_array)
    return outputs
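import pyspark.sql.functions as psf
from pyspark.sql.types import ArrayType, DoubleType

# DoubleType() or whichever element type custom_library returns
do_something_udf = psf.udf(do_something, ArrayType(ArrayType(DoubleType())))

# withColumn returns a new DataFrame; df itself is not modified
outputs_df = df.withColumn("outputs", psf.explode(do_something_udf("counts")))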
Since custom_library returns three lists per array, outputs_df will have three times as many rows as df.
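A small pure-Python model of that row count, assuming (as the question states) that custom_library returns three lists per input array. `fake_custom_library` is a hypothetical stand-in for it, the two rows reuse the arrays from the question's table treated as complete three-element arrays, and `explode_outputs` only mimics what `explode` does: it emits one output row per element of the array the UDF returns.

```python
def fake_custom_library(i_array):
    # Hypothetical stand-in for custom_library: returns 3 new lists per array.
    return [[x * 2 for x in i_array], [x + 1 for x in i_array], sorted(i_array)]


def explode_outputs(rows, fn):
    # Mimics df.withColumn("outputs", psf.explode(udf("counts"))):
    # each input row is repeated once per element of the list fn returns.
    exploded = []
    for row in rows:
        for item in fn(row["counts"]):
            exploded.append({**row, "outputs": item})
    return exploded


rows = [
    {"id": 1, "counts": [8.0, 2.0, 3.0]},
    {"id": 2, "counts": [1.0, 6.0, 3.0]},
]
result = explode_outputs(rows, fake_custom_library)
print(len(result))  # 6, i.e. three times len(rows)
```

Each input row keeps its id and counts and gains one "outputs" value per returned list, which is where the factor of three comes from.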
Upvotes: 2