Siddharth Satpathy

Reputation: 3043

How to find average of array of arrays in PySpark

I have a PySpark DataFrame in which one of the columns (say B) is an array of arrays. Here is the DataFrame:

+---+-----------------------------+---+
|A  |B                            |C  |
+---+-----------------------------+---+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |
|a  |[[5.0], [20.0, 80.0]]        |d  |
|a  |[[5.0], [25.0, 75.0]]        |e  |
|b  |[[5.0], [25.0, 75.0]]        |f  |
|b  |[[5.0], [12.0, 88.0]]        |g  |
+---+-----------------------------+---+

I want to find the number of elements and the average of all elements (as separate columns) for each row.

Below is the expected output:

+---+-----------------------------+---+---+------+
|A  |B                            |C  |Num|   Avg|
+---+-----------------------------+---+---+------+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |4  | 23.75|
|a  |[[5.0], [20.0, 80.0]]        |d  |3  | 35.00|
|a  |[[5.0], [25.0, 75.0]]        |e  |3  | 35.00|
|b  |[[5.0], [25.0, 75.0]]        |f  |3  | 35.00|
|b  |[[5.0], [12.0, 88.0]]        |g  |3  | 35.00|
+---+-----------------------------+---+---+------+

What is an efficient way to find the average of all elements in an array of arrays (per row) in PySpark?

At present, I am using UDFs to do this. Below is my code:

from pyspark.sql import functions as F
import pyspark.sql.types as T
import numpy as np

# UDF to find the number of elements
def len_array_of_arrays(anomaly_in_issue_group_col):
    return sum(len(array_element) for array_element in anomaly_in_issue_group_col)

udf_len_array_of_arrays = F.udf(len_array_of_arrays, T.IntegerType())

# UDF to find the average of all elements
def avg_array_of_arrays(anomaly_in_issue_group_col):
    return np.mean([element for array_element in anomaly_in_issue_group_col
                    for element in array_element])

udf_avg_array_of_arrays = F.udf(avg_array_of_arrays, T.DecimalType())

df.withColumn("Num", udf_len_array_of_arrays(F.col("B"))).withColumn(
    "Avg", udf_avg_array_of_arrays(F.col("B"))
).show(20, False)

The UDF for finding the number of elements in each row works, but the UDF for finding the averages throws the following error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-176-3253feca2963> in <module>()
      1 #df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).show(20, False)
----> 2 df.withColumn("Num" , udf_len_array_of_arrays(F.col("B")) ).withColumn("Avg" ,  udf_avg_array_of_arrays(F.col("B")) ).show(20, False)

/usr/lib/spark/python/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
    378             print(self._jdf.showString(n, 20, vertical))
    379         else:
--> 380             print(self._jdf.showString(n, int(truncate), vertical))
    381 
    382     def __repr__(self):

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/usr/lib/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(
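
The traceback is cut off before the Java exception message, but a likely cause is the return type: T.DecimalType() defaults to precision (10, 0) and expects decimal.Decimal values, while np.mean returns a numpy.float64, which Spark cannot deserialize. A minimal sketch of a fix, assuming a double-precision result is acceptable (avg_array_of_arrays_fixed is an illustrative name, not from the original code):

from pyspark.sql import functions as F
import pyspark.sql.types as T

# Sketch: return a plain Python float and declare the matching DoubleType,
# instead of pairing DecimalType with a numpy.float64 return value.
def avg_array_of_arrays_fixed(nested):
    flat = [element for array_element in nested for element in array_element]
    return float(sum(flat) / len(flat)) if flat else None

udf_avg_array_of_arrays = F.udf(avg_array_of_arrays_fixed, T.DoubleType())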

Upvotes: 0

Views: 2356

Answers (2)

Oliver W.

Reputation: 13459

Since Spark 1.4:

explode() the column containing arrays, as many times as there are nesting levels. Use monotonically_increasing_id() to create an extra grouping key to prevent duplicate rows from being combined:

from pyspark.sql.functions import explode, sum, lit, avg, monotonically_increasing_id

df = spark.createDataFrame(
    (("a", [[1], [2, 3], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),
     ("a", [[5], [6, 0], [4]], "foo"),  # DUPE!
     ("b", [[2, 3], [4]], "foo")),
    schema=("category", "arrays", "foo"))

df2 = (df.withColumn("id", monotonically_increasing_id())
       .withColumn("subarray", explode("arrays"))
       .withColumn("subarray", explode("subarray"))  # unnest another level
       .groupBy("category", "arrays", "foo", "id")
       .agg(sum(lit(1)).alias("number_of_elements"),
            avg("subarray").alias("avg")).drop("id"))
df2.show()
# +--------+------------------+---+------------------+----+  
# |category|            arrays|foo|number_of_elements| avg|
# +--------+------------------+---+------------------+----+
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       b|     [[2, 3], [4]]|foo|                 3| 3.0|
# |       a|[[5], [6, 0], [4]]|foo|                 4|3.75|
# |       a|[[1], [2, 3], [4]]|foo|                 4| 2.5|
# +--------+------------------+---+------------------+----+

Spark 2.4 introduced 24 functions that deal with complex types, along with higher-order functions (functions that take functions as an argument, like Python 3's functools.reduce). They take away the boilerplate you see above. If you're on Spark 2.4+, see the answer from jxc.

Upvotes: 1

jxc

Reputation: 14008

For Spark 2.4+, use flatten + aggregate:

from pyspark.sql.functions import expr

df.withColumn("Avg", expr("""
    aggregate(
        flatten(B),
        (double(0) as total, int(0) as cnt),
        (x, y) -> (x.total + y, x.cnt + 1),
        z -> round(z.total / z.cnt, 2)
    )
""")).show(truncate=False)
+---+-----------------------------+---+-----+
|A  |B                            |C  |Avg  |
+---+-----------------------------+---+-----+
|a  |[[5.0], [25.0, 25.0], [40.0]]|c  |23.75|
|a  |[[5.0], [20.0, 80.0]]        |d  |35.0 |
|a  |[[5.0], [25.0, 75.0]]        |e  |35.0 |
|b  |[[5.0], [25.0, 75.0]]        |f  |35.0 |
|b  |[[5.0], [12.0, 88.0]]        |g  |35.0 |
+---+-----------------------------+---+-----+
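
The question also asks for a Num column; on Spark 2.4+ that needs no UDF either, since size over flatten counts every element across the nested arrays. A minimal sketch, assuming the same df as above:

from pyspark.sql.functions import col, flatten, size

# Count all elements across the nested arrays in B (flatten needs Spark 2.4+).
df.withColumn("Num", size(flatten(col("B")))).show(truncate=False)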

Upvotes: 2
