Vas

Reputation: 343

Apply a custom function to a column of array type in a dataframe

I have a dataframe with a column named 'counts' and I would like to apply a custom function "do_something" to each element of the column, i.e. to each array. I do not want to modify the dataframe; I just want to do a separate operation on the counts column. All arrays in the column have the same size.

+---+---------------+
| id|         counts|
+---+---------------+
|  1|[8.0, 2.0, 3.0]|
|  2|[1.0, 6.0, 3.0]|
+---+---------------+

When I try this:

df.select('counts').rdd.foreach(lambda x: do_something(x))

it fails on the line above with the following error (even if I try without the lambda, it gives the same error):

Py4JJavaError                             Traceback (most recent call last)
<ipython-input-...> in <module>()
----> 1 df.select('counts').rdd.foreach(lambda x: do_something(x))

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in foreach(self, f)
    745             f(x)
    746             return iter([])
--> 747         self.mapPartitions(processPartition).count()  # Force evaluation
    748
    749     def foreachPartition(self, f):

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in count(self)
   1002         3
   1003         """
-> 1004         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
   1005
   1006     def stats(self):

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in sum(self)
    993         6.0
    994         """
--> 995         return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
    996
    997     def count(self):

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in fold(self, zeroValue, op)
    867         # zeroValue provided to each partition is unique from the one provided
    868         # to the final reduce call
--> 869         vals = self.mapPartitions(func).collect()
    870         return reduce(op, vals, zeroValue)
    871

/usr/hdp/2.5.3.0-37/spark/python/pyspark/rdd.py in collect(self)
    769         """
    770         with SCCallSiteSync(self.context) as css:
--> 771             port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
    772         return list(_load_from_socket(port, self._jrdd_deserializer))
    773

/usr/hdp/2.5.3.0-37/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
    811         answer = self.gateway_client.send_command(command)
    812         return_value = get_return_value(
--> 813             answer, self.gateway_client, self.target_id, self.name)
    814
    815         for temp_arg in temp_args:

/usr/hdp/2.5.3.0-37/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     43     def deco(*a, **kw):
     44         try:
---> 45             return f(*a, **kw)
     46         except py4j.protocol.Py4JJavaError as e:
     47             s = e.java_exception.toString()

/usr/hdp/2.5.3.0-37/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    306                 raise Py4JJavaError(
    307                     "An error occurred while calling {0}{1}{2}.\n".
--> 308                     format(target_id, ".", name), value)
    309             else:
    310                 raise Py4JError(

This happens even though all input arrays have the same size. The function I am applying is:

big_list = []

def do_something(i_array):
    outputs = custom_library(i_array)  # takes an array as input and returns 3 new lists
    big_list.extend(outputs)
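
For concreteness, this is the behaviour I expect on plain Python lists (custom_library replaced by a trivial placeholder here, since it is my own code):

custom_library = lambda a: [a, [x * 2 for x in a], [x + 1 for x in a]]  # placeholder only

big_list = []
for arr in [[8.0, 2.0, 3.0], [1.0, 6.0, 3.0]]:
    big_list.extend(custom_library(arr))

print(len(big_list))  # 6: three new lists per input array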

Upvotes: 1

Views: 2805

Answers (1)

MaFF

Reputation: 10086

Your UDF modifies a Python object that is:

  • external to the dataframe: even if the function worked, you wouldn't be able to access the values, since you're not returning them to the rows of your dataframe (see the sketch after this list)
  • huge: it will have at least three times as many elements as the number of rows in your dataframe
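
A minimal sketch of the first point (toy data shaped like yours, assuming a SparkSession): whatever do_something appends to big_list happens on the executors, against a pickled copy of the list, so the driver-side big_list stays empty.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, [8.0, 2.0, 3.0]), (2, [1.0, 6.0, 3.0])],
                           ["id", "counts"])

big_list = []

def do_something(row):
    # Runs on an executor against a serialized copy of big_list,
    # so the driver's big_list is never updated.
    big_list.extend(row['counts'])

df.select('counts').rdd.foreach(do_something)
print(len(big_list))  # prints 0 on the driver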

You can try doing this instead:

def do_something(i_array):
    outputs = custom_library(i_array)
    return outputs

import pyspark.sql.functions as psf
from pyspark.sql.types import ArrayType, DoubleType

do_something_udf = psf.udf(do_something, ArrayType(ArrayType(DoubleType())))

Use DoubleType(), or whichever element type your function actually returns.

df.withColumn("outputs", psf.explode(do_something_udf("counts")))

You'll have three times as many rows as in df.
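
Putting the pieces together, a self-contained sketch (custom_library stubbed out as a hypothetical placeholder that returns three lists of doubles per input array):

import pyspark.sql.functions as psf
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, [8.0, 2.0, 3.0]), (2, [1.0, 6.0, 3.0])],
                           ["id", "counts"])

def custom_library(i_array):
    # Hypothetical stand-in for the real library call.
    return [[x + 1 for x in i_array],
            [x * 2 for x in i_array],
            [x / 2 for x in i_array]]

def do_something(i_array):
    return custom_library(i_array)

do_something_udf = psf.udf(do_something, ArrayType(ArrayType(DoubleType())))

result = df.withColumn("outputs", psf.explode(do_something_udf("counts")))
result.show(truncate=False)  # 6 rows: one per output list, 3 per input row

If you still need the flat Python list on the driver side, you can collect the exploded column afterwards, e.g. big_list = [r['outputs'] for r in result.select('outputs').collect()].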

Upvotes: 2
