Reputation: 3043
I have a PySpark DataFrame in which one of the columns (say B) is an array of arrays. This is the DataFrame:
+---+-----------------------------+---+
|A  |B                            |C  |
+---+-----------------------------+---+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |
|a  |[[5.0], [20.0, 80.0]]        |d  |
|a  |[[5.0], [25.0, 75.0]]        |e  |
|b  |[[5.0], [25.0, 75.0]]        |f  |
|b  |[[5.0], [12.0, 88.0]]        |g  |
+---+-----------------------------+---+
I want to find the number of elements and the average of all elements (as separate columns) for each row.
Below is the expected output:
+---+-----------------------------+---+---+------+
|A  |B                            |C  |Num|   Avg|
+---+-----------------------------+---+---+------+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |4  | 23.75|
|a  |[[5.0], [20.0, 80.0]]        |d  |3  | 35.00|
|a  |[[5.0], [25.0, 75.0]]        |e  |3  | 35.00|
|b  |[[5.0], [25.0, 75.0]]        |f  |3  | 35.00|
|b  |[[5.0], [12.0, 88.0]]        |g  |3  | 35.00|
+---+-----------------------------+---+---+------+
What is an efficient way to find the average of all elements in an array of arrays (in each row) in PySpark?
Presently, I am using UDFs to do this. Below is the code I have at present:
from pyspark.sql import functions as F
import pyspark.sql.types as T
from pyspark.sql import *
from pyspark.sql.types import DecimalType
from pyspark.sql.functions import udf
import numpy as np

# UDF to find the number of elements
def len_array_of_arrays(anomaly_in_issue_group_col):
    return sum([len(array_element) for array_element in anomaly_in_issue_group_col])

udf_len_array_of_arrays = F.udf(len_array_of_arrays, T.IntegerType())

# UDF to find the average of all elements
def avg_array_of_arrays(anomaly_in_issue_group_col):
    return np.mean([element for array_element in anomaly_in_issue_group_col for element in array_element])

udf_avg_array_of_arrays = F.udf(avg_array_of_arrays, T.DecimalType())

df.withColumn("Num", udf_len_array_of_arrays(F.col("B"))).withColumn(
    "Avg", udf_avg_array_of_arrays(F.col("B"))
).show(20, False)
The UDF for finding the number of elements in each row works, but the UDF for finding the averages throws the following error:
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-176-3253feca2963> in <module>()
1 #df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).show(20, False)
----> 2 df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).withColumn("Avg" , udf_avg_array_of_arrays(F.col("B")) ).show(20, False)
/usr/lib/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
378 print(self._jdf.showString(n, 20, vertical))
379 else:
--> 380 print(self._jdf.showString(n, int(truncate), vertical))
381
382 def __repr__(self):
/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
1255 answer = self.gateway_client.send_command(command)
1256 return_value = get_return_value(
-> 1257 answer, self.gateway_client, self.target_id, self.name)
1258
1259 for temp_arg in temp_args:
/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
61 def deco(*a, **kw):
62 try:
---> 63 return f(*a, **kw)
64 except py4j.protocol.Py4JJavaError as e:
65 s = e.java_exception.toString()
/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
--> 328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
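The traceback above is cut off before the underlying Java exception, but a likely cause is that np.mean() returns a numpy.float64, which Spark cannot serialize into the declared DecimalType(). A minimal sketch of a possible fix, keeping the UDF approach, would be to return a plain Python float and declare a DoubleType() return type (treat this as an untested sketch, not part of the original question):
from pyspark.sql import functions as F
import pyspark.sql.types as T
import numpy as np

# Sketch of a fix: cast the numpy result to a plain Python float and use DoubleType
def avg_array_of_arrays(anomaly_in_issue_group_col):
    flat = [element for array_element in anomaly_in_issue_group_col
            for element in array_element]
    return float(np.mean(flat))  # plain float, not numpy.float64

udf_avg_array_of_arrays = F.udf(avg_array_of_arrays, T.DoubleType())
df.withColumn("Avg", udf_avg_array_of_arrays(F.col("B"))).show(20, False)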
Upvotes: 0
Views: 2356
Reputation: 13459
Since Spark 1.4: explode() the column containing arrays, as many times as there are nesting levels. Use monotonically_increasing_id() to create an extra grouping key to prevent duplicate rows from being combined:
from pyspark.sql.functions import explode, sum, lit, avg, monotonically_increasing_id
df = spark.createDataFrame(
    (("a", [[1], [2, 3], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),  # DUPE!
     ("b", [[2, 3], [4]], "foo")),
    schema=("category", "arrays", "foo"))

df2 = (df.withColumn("id", monotonically_increasing_id())
       .withColumn("subarray", explode("arrays"))
       .withColumn("subarray", explode("subarray"))  # unnest another level
       .groupBy("category", "arrays", "foo", "id")
       .agg(sum(lit(1)).alias("number_of_elements"),
            avg("subarray").alias("avg"))
       .drop("id"))
df2.show()
# +--------+------------------+---+------------------+----+
# |category| arrays|foo|number_of_elements| avg|
# +--------+------------------+---+------------------+----+
# | a|[[5], [6, 0], [4]]|foo| 4|3.75|
# | b| [[2, 3], [4]]|foo| 3| 3.0|
# | a|[[5], [6, 0], [4]]|foo| 4|3.75|
# | a|[[1], [2, 3], [4]]|foo| 4| 2.5|
# +--------+------------------+---+------------------+----+
Spark 2.4 introduced 24 functions that deal with complex types, along with higher-order functions (functions that take functions as arguments, like Python 3's functools.reduce). They take away the boilerplate that you see above. If you're on Spark 2.4+, see the answer from jxc.
Upvotes: 1
Reputation: 14008
For Spark 2.4+, use flatten + aggregate:
from pyspark.sql.functions import expr
df.withColumn("Avg", expr("""
aggregate(
flatten(B)
, (double(0) as total, int(0) as cnt)
, (x,y) -> (x.total+y, x.cnt+1)
, z -> round(z.total/z.cnt,2)
)
""")).show()
+-----------------------------+---+-----+
|B                            |C  |Avg  |
+-----------------------------+---+-----+
|[[5.0], [25.0, 25.0], [40.0]]|c  |23.75|
|[[5.0], [25.0, 80.0]]        |d  |36.67|
|[[5.0], [25.0, 75.0]]        |e  |35.0 |
+-----------------------------+---+-----+
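The question also asks for the element count; assuming the same Spark 2.4+ built-ins and the same df, size(flatten(B)) gives that directly, so both new columns could be added in one pass (a sketch, not part of the original answer):
from pyspark.sql.functions import expr

# size(flatten(B)) counts every element after collapsing the nested arrays
df.withColumn("Num", expr("size(flatten(B))")).withColumn("Avg", expr("""
    aggregate(
        flatten(B)
      , (double(0) as total, int(0) as cnt)
      , (x,y) -> (x.total+y, x.cnt+1)
      , z -> round(z.total/z.cnt,2)
    )
""")).show()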
Upvotes: 2