Reputation: 3534
I have the following DataFrame, ordered by group, n1, n2:
+-----+--+--+------+------+
|group|n1|n2|n1_ptr|n2_ptr|
+-----+--+--+------+------+
| 1| 0| 0| 1| 1|
| 1| 1| 1| 2| 2|
| 1| 1| 5| 2| 6|
| 1| 2| 2| 3| 3|
| 1| 2| 6| 3| 7|
| 1| 3| 3| 4| 4|
| 1| 3| 7| null| null|
| 1| 4| 4| 5| 5|
| 1| 5| 1| null| null|
| 1| 5| 5| null| null|
+-----+--+--+------+------+
Each row's n1_ptr and n2_ptr values refer to the n1 and n2 values of some other row in the group that comes later in the ordering. In other words, n1_ptr and n2_ptr are effectively pointers to another row. I want to use these pointers to identify chains of (n1, n2) pairs. For example, the chains in the given data would be: (0,0) -> (1,1) -> (2,2) -> (3,3) -> (4,4) -> (5,5); (1,5) -> (2,6) -> (3,7); and (5,1).
The ultimate goal is to consolidate each chain into a single row of a DataFrame describing the min and max n1 and n2 values in each chain. Continuing the example, this would yield:
+-----+------+------+------+------+
|group|n1_min|n2_min|n1_max|n2_max|
+-----+------+------+------+------+
| 1| 0| 0| 5| 5|
| 1| 1| 5| 3| 7|
| 1| 5| 1| 5| 1|
+-----+------+------+------+------+
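To make the intent concrete, here is the chain-walking logic in plain Python over the sample data (just a sketch; group is ignored since there is only one group):
rows = [(0, 0, 1, 1), (1, 1, 2, 2), (1, 5, 2, 6), (2, 2, 3, 3), (2, 6, 3, 7),
        (3, 3, 4, 4), (3, 7, None, None), (4, 4, 5, 5), (5, 1, None, None),
        (5, 5, None, None)]  # (n1, n2, n1_ptr, n2_ptr)
next_pair = {(n1, n2): (p1, p2) for n1, n2, p1, p2 in rows if p1 is not None}
targets = set(next_pair.values())  # pairs that some other row points to
heads = [(n1, n2) for n1, n2, _, _ in rows if (n1, n2) not in targets]
for head in heads:  # walk each chain from its head to its end
    chain, cur = [head], next_pair.get(head)
    while cur is not None:
        chain.append(cur)
        cur = next_pair.get(cur)
    n1s, n2s = zip(*chain)
    print(min(n1s), min(n2s), max(n1s), max(n2s))  # n1_min, n2_min, n1_max, n2_max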
It seems like a UDF might do the trick, but I am concerned about performance. Is there a more sensible/performant way to go about this?
Upvotes: 0
Views: 911
Reputation: 10076
A good solution would be to use graphframes: https://graphframes.github.io/quick-start.html
First let's change the structure of your initial dataframe:
import pyspark.sql.functions as psf

# Keep the rows with null pointers for now: their src values (e.g. (5,1)) are still chain members.
df = sc.parallelize([[1, 0, 0, 1, 1], [1, 1, 1, 2, 2], [1, 1, 5, 2, 6],
                     [1, 2, 2, 3, 3], [1, 2, 6, 3, 7], [1, 3, 3, 4, 4],
                     [1, 3, 7, None, None], [1, 4, 4, 5, 5], [1, 5, 1, None, None],
                     [1, 5, 5, None, None]]).toDF(["group", "n1", "n2", "n1_ptr", "n2_ptr"])
df = df.select(
    "group",
    psf.struct("n1", "n2").alias("src"),
    psf.struct(df.n1_ptr.alias("n1"), df.n2_ptr.alias("n2")).alias("dst"))
From df we'll build a vertex and an edge dataframe:
v = df.select(
    "group",
    psf.explode(psf.array("src", "dst")).alias("id")
).filter("id.n1 IS NOT NULL")  # drop the [null, null] structs coming from rows without pointers
e = df.drop("group").filter("dst.n1 IS NOT NULL")  # an edge only exists where the pointer is set
The next step is to find all connected components using graphframes:
from graphframes import *
g = GraphFrame(v, e)
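# Note: depending on your graphframes/Spark version, connectedComponents() may
# require a checkpoint directory, e.g. sc.setCheckpointDir("/tmp/graphframes-checkpoints").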
res = g.connectedComponents()
+-----+-----+------------+
|group| id| component|
+-----+-----+------------+
| 1|[0,0]|309237645312|
| 1|[1,1]|309237645312|
| 1|[1,1]|309237645312|
| 1|[2,2]|309237645312|
| 1|[1,5]| 85899345920|
| 1|[2,6]| 85899345920|
| 1|[2,2]|309237645312|
| 1|[3,3]|309237645312|
| 1|[2,6]| 85899345920|
| 1|[3,7]| 85899345920|
| 1|[3,3]|309237645312|
| 1|[4,4]|309237645312|
| 1|[3,7]| 85899345920|
| 1|[4,4]|309237645312|
| 1|[5,5]|309237645312|
| 1|[5,1]|292057776128|
| 1|[5,5]|309237645312|
+-----+-----+------------+
Now, since the relation encoded in your graph edges implies that the node values n1 and n2 increase monotonically along each chain, we can simply aggregate by component and compute the min and max:
res.groupBy("group", "component").agg(
psf.min("id").alias("min_id"),
psf.max("id").alias("max_id")
)
+-----+------------+------+------+
|group| component|min_id|max_id|
+-----+------------+------+------+
| 1|309237645312| [0,0]| [5,5]|
| 1| 85899345920| [1,5]| [3,7]|
| 1|292057776128| [5,1]| [5,1]|
+-----+------------+------+------+
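To get the exact column layout asked for in the question, the struct columns can then be unpacked with a final select:
res.groupBy("group", "component").agg(
    psf.min("id").alias("min_id"),
    psf.max("id").alias("max_id")
).select(
    "group",
    psf.col("min_id.n1").alias("n1_min"),
    psf.col("min_id.n2").alias("n2_min"),
    psf.col("max_id.n1").alias("n1_max"),
    psf.col("max_id.n2").alias("n2_max"))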
Upvotes: 2