Reputation: 163
I have the following code:
import functools
from pyspark.sql import functions as f
from pyspark.sql.types import ArrayType, IntegerType

elements = spark.createDataFrame([
    ('g1', 'a', 1), ('g1', 'a', 2), ('g1', 'b', 1), ('g1', 'b', 3),
    ('g2', 'c', 1), ('g2', 'c', 3), ('g2', 'c', 2), ('g2', 'd', 4),
], ['group', 'instance', 'element'])
all_elements_per_instance = elements.groupBy("group", "instance").agg(f.collect_set('element').alias('elements'))
# +-----+--------+---------+
# |group|instance| elements|
# +-----+--------+---------+
# | g1| b| [1, 3]|
# | g1| a| [1, 2]|
# | g2| c|[1, 2, 3]|
# | g2| d| [4]|
# +-----+--------+---------+
@f.udf(ArrayType(IntegerType()))
def intersect(elements):
    return list(functools.reduce(lambda x, y: set(x).intersection(set(y)), elements))
all_intersect_elements_per_group = all_elements_per_instance.groupBy("group")\
.agg(intersect(f.collect_list("elements")).alias("intersection"))
# +-----+------------+
# |group|intersection|
# +-----+------------+
# | g1| [1]|
# | g2| []|
# +-----+------------+
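For reference, the reduce inside the UDF is plain Python set logic, so its behavior can be checked outside Spark (a standalone sketch, not Spark code):

```python
import functools

# The same fold the UDF performs: intersect all element lists per group.
collected = [[1, 3], [1, 2]]  # collect_list("elements") for group g1
intersection = list(functools.reduce(
    lambda x, y: set(x).intersection(set(y)), collected
))
# For g1 this yields [1]; for g2 ([[1, 2, 3], [4]]) it yields [].
```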
Is there a way to avoid the UDF (as it's costly), and use somehow the f.array_intersect
or similar function as an aggregation function?
Upvotes: 1
Views: 1099
Reputation: 32670
If your intent is to find the list of elements that are shared by at least 2 instances in each group, you can actually simplify it by counting distinct instances for each group/element using a Window, then grouping by group and collecting only the elements that have a count > 1:
from pyspark.sql import Window
from pyspark.sql import functions as F
result = elements.withColumn(
    "cnt",
    F.size(F.collect_set("instance").over(Window.partitionBy("group", "element")))
).groupBy("group").agg(
    F.collect_set(
        F.when(F.col("cnt") > 1, F.col("element"))
    ).alias('intersection')
)
result.show()
#+-----+------------+
#|group|intersection|
#+-----+------------+
#| g2| []|
#| g1| [1]|
#+-----+------------+
I used collect_set + size because the function countDistinct doesn't support Window.
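The counting logic itself is easy to sanity-check in plain Python (a sketch of the same idea, not of how Spark executes it):

```python
from collections import defaultdict

rows = [
    ('g1', 'a', 1), ('g1', 'a', 2), ('g1', 'b', 1), ('g1', 'b', 3),
    ('g2', 'c', 1), ('g2', 'c', 3), ('g2', 'c', 2), ('g2', 'd', 4),
]

# Distinct instances per (group, element) -- the Window + collect_set step.
instances = defaultdict(set)
for group, instance, element in rows:
    instances[(group, element)].add(instance)

# Keep elements seen by more than one instance -- the collect_set(when(...)) step.
shared = defaultdict(set)
for (group, element), insts in instances.items():
    if len(insts) > 1:
        shared[group].add(element)

print(dict(shared))  # {'g1': {1}}  (g2 has no shared element, matching the empty array)
```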
Upvotes: 1
Reputation: 42352
You can use the higher-order function aggregate to do an array_intersect on the elements:
import pyspark.sql.functions as f
result = all_elements_per_instance.groupBy('group').agg(
    f.expr("""
        aggregate(
            collect_list(elements),
            collect_list(elements)[0],
            (acc, x) -> array_intersect(acc, x)
        ) as intersection
    """)
)
result.show()
#+-----+------------+
#|group|intersection|
#+-----+------------+
#|   g2|          []|
#|   g1|         [1]|
#+-----+------------+
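The SQL aggregate here is a left fold: it starts from the first collected array and intersects each subsequent one into the accumulator. That fold can be sketched in plain Python (a rough model of the semantics, not Spark's execution; note Spark's array_intersect also deduplicates):

```python
# Mirrors: aggregate(collect_list(elements), collect_list(elements)[0],
#                    (acc, x) -> array_intersect(acc, x))
def fold_intersect(arrays):
    acc = arrays[0]
    for x in arrays:
        acc = [e for e in acc if e in x]  # keep only elements common to acc and x
    return acc

print(fold_intersect([[1, 3], [1, 2]]))  # [1]  (group g1)
print(fold_intersect([[1, 2, 3], [4]]))  # []   (group g2)
```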
Upvotes: 3