Litan Ilany

Reputation: 163

How can I conduct an intersection of multiple arrays into single array on PySpark, without UDF?

I have the following code:

import functools

from pyspark.sql import functions as f
from pyspark.sql.types import ArrayType, IntegerType

elements = spark.createDataFrame([
('g1', 'a', 1), ('g1', 'a', 2), ('g1', 'b', 1), ('g1', 'b', 3),
('g2', 'c', 1), ('g2', 'c', 3), ('g2', 'c', 2), ('g2', 'd', 4),
], ['group', 'instance', 'element'])

all_elements_per_instance = elements.groupBy("group", "instance").agg(f.collect_set('element').alias('elements'))

# +-----+--------+---------+
# |group|instance| elements|
# +-----+--------+---------+
# |   g1|       b|   [1, 3]|
# |   g1|       a|   [1, 2]|
# |   g2|       c|[1, 2, 3]|
# |   g2|       d|      [4]|
# +-----+--------+---------+

@f.udf(ArrayType(IntegerType()))
def intersect(elements):
    return list(functools.reduce(lambda x, y: set(x).intersection(set(y)), elements))

all_intersect_elements_per_group = all_elements_per_instance.groupBy("group")\
    .agg(intersect(f.collect_list("elements")).alias("intersection"))

# +-----+------------+
# |group|intersection|
# +-----+------------+
# |   g1|         [1]|
# |   g2|          []|
# +-----+------------+

Is there a way to avoid the UDF (as it's costly) and somehow use f.array_intersect or a similar function as an aggregation function?

Upvotes: 1

Views: 1099

Answers (2)

blackbishop

Reputation: 32670

If your intent is to find the list of elements that are shared by at least 2 instances in each group, you can actually simplify this by counting the distinct instances for each group/element pair using a Window, then grouping by group and collecting only the elements whose count is > 1:

from pyspark.sql import Window
from pyspark.sql import functions as F

result = elements.withColumn(
    "cnt",
    F.size(F.collect_set("instance").over(Window.partitionBy("group", "element")))
).groupBy("group").agg(
    F.collect_set(
        F.when(F.col("cnt") > 1, F.col("element"))
    ).alias('intersection')
)

result.show()

#+-----+------------+
#|group|intersection|
#+-----+------------+
#|   g2|          []|
#|   g1|         [1]|
#+-----+------------+

I used collect_set + size because the countDistinct function doesn't support Window.
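
For illustration, a minimal sketch (not part of the original answer) of that substitution, assuming the same elements DataFrame: a distinct aggregate over a window is rejected at analysis time, so the distinct count comes from size(collect_set(...)) instead.

from pyspark.sql import Window
from pyspark.sql import functions as F

w = Window.partitionBy("group", "element")

# Raises an AnalysisException (distinct window functions are not supported):
# elements.withColumn("cnt", F.countDistinct("instance").over(w))

# size(collect_set(...)) over the same window yields the distinct count:
elements.withColumn("cnt", F.size(F.collect_set("instance").over(w))).show()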

Upvotes: 1

mck

Reputation: 42352

You can use the higher-order function aggregate to do an array_intersect on the elements:

import pyspark.sql.functions as f
result = all_elements_per_instance.groupBy('group').agg(
    f.expr("""
        aggregate(
            collect_list(elements),
            collect_list(elements)[0],
            (acc, x) -> array_intersect(acc, x)
        ) as intersection
    """)
)

result.show()

# +-----+------------+
# |group|intersection|
# +-----+------------+
# |   g2|          []|
# |   g1|         [1]|
# +-----+------------+
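
As a usage note (not from the original answer): the initial value collect_list(elements)[0] seeds the fold with the first collected array, so the intersection for a group with a single instance is that instance's array. On Spark 3.1+ the same fold can also be written with the DataFrame API's f.aggregate instead of an SQL expression; a sketch, assuming that version is available:

import pyspark.sql.functions as f

# Collect the per-instance arrays first, then fold array_intersect over them.
collected = all_elements_per_instance.groupBy('group').agg(
    f.collect_list('elements').alias('all_elements')
)

result = collected.select(
    'group',
    f.aggregate(
        'all_elements',
        f.col('all_elements')[0],                 # seed: first array in the list
        lambda acc, x: f.array_intersect(acc, x)  # intersect with each remaining array
    ).alias('intersection')
)
result.show()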

Upvotes: 3
