alan

Reputation: 3534

[Py]Spark SQL: Merge two or more rows based on equal values in different columns

I have the following DataFrame, ordered by group, n1, n2:

+-----+--+--+------+------+                                        
|group|n1|n2|n1_ptr|n2_ptr|                                
+-----+--+--+------+------+                                        
|    1| 0| 0|     1|     1|                                        
|    1| 1| 1|     2|     2|                                        
|    1| 1| 5|     2|     6|                                        
|    1| 2| 2|     3|     3|                                        
|    1| 2| 6|     3|     7|                                        
|    1| 3| 3|     4|     4|                                        
|    1| 3| 7|  null|  null|                                        
|    1| 4| 4|     5|     5|                                        
|    1| 5| 1|  null|  null|                                        
|    1| 5| 5|  null|  null|                                        
+-----+--+--+------+------+

Each row's n1_ptr and n2_ptr values refer to the n1 and n2 values of some other row in the group that comes later in the ordering. In other words, n1_ptr and n2_ptr are effectively pointers to another row. I want to use these pointers to identify chains of (n1, n2) pairs. For example, the chains in the given data would be: (0,0) -> (1,1) -> (2,2) -> (3,3) -> (4,4) -> (5,5); (1,5) -> (2,6) -> (3,7); and (5,1).
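For concreteness, here is the chain-following logic as a minimal pure-Python sketch over the example data (just to pin down the semantics; the real data is far too large to collect and walk like this):

# Follow (n1_ptr, n2_ptr) starting from each (n1, n2) pair
# that no other row points to (the chain heads).
rows = {(0, 0): (1, 1), (1, 1): (2, 2), (1, 5): (2, 6), (2, 2): (3, 3),
        (2, 6): (3, 7), (3, 3): (4, 4), (3, 7): None, (4, 4): (5, 5),
        (5, 1): None, (5, 5): None}
heads = set(rows) - {p for p in rows.values() if p is not None}
for head in sorted(heads):
    chain, node = [], head
    while node is not None:
        chain.append(node)
        node = rows.get(node)
    print(" -> ".join(map(str, chain)))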

The ultimate goal is to consolidate each chain into a single row in a DataFrame describing the min and max n1 and n2 values in each chain. Continuing the example, this would yield

+-----+------+------+------+------+
|group|n1_min|n2_min|n1_max|n2_max|
+-----+------+------+------+------+       
|    1|     0|     0|     5|     5|
|    1|     1|     5|     3|     7|
|    1|     5|     1|     5|     1| 
+-----+------+------+------+------+

It seems like a udf might do the trick, but I am concerned about performance. Is there a more sensible/performant way to go about this?

Upvotes: 0

Views: 911

Answers (1)

MaFF

Reputation: 10076

A good solution would be to use graphframes: https://graphframes.github.io/quick-start.html.

First, let's change the structure of your initial DataFrame. We keep the rows with null pointers for now; their src values are still needed as vertices:

import pyspark.sql.functions as psf

# Assumes a pyspark shell, where `sc` (SparkContext) is predefined.
df = sc.parallelize([[1, 0, 0, 1, 1], [1, 1, 1, 2, 2], [1, 1, 5, 2, 6],
                     [1, 2, 2, 3, 3], [1, 2, 6, 3, 7], [1, 3, 3, 4, 4],
                     [1, 3, 7, None, None], [1, 4, 4, 5, 5], [1, 5, 1, None, None],
                     [1, 5, 5, None, None]]).toDF(["group", "n1", "n2", "n1_ptr", "n2_ptr"])
df = df.select(
    "group",
    psf.struct("n1", "n2").alias("src"), 
    psf.struct(df.n1_ptr.alias("n1"), df.n2_ptr.alias("n2")).alias("dst"))
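
After this select, each row carries its endpoint pairs as structs; the schema looks roughly like this (assuming the columns come in as longs):

df.printSchema()
# root
#  |-- group: long (nullable = true)
#  |-- src: struct (nullable = false)
#  |    |-- n1: long (nullable = true)
#  |    |-- n2: long (nullable = true)
#  |-- dst: struct (nullable = false)
#  |    |-- n1: long (nullable = true)
#  |    |-- n2: long (nullable = true)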

From df we'll build a vertex DataFrame (GraphFrames expects an id column) and an edge DataFrame (src and dst columns). This is where we drop the null pointers, so that chain heads without incoming edges, like (5,1), still make it into the vertex set:

# Null pointers mark chain tails: drop them from the vertex ids and from
# the edges, but keep every row's src vertex.
v = df.select(
    "group",
    psf.explode(psf.array("src", "dst")).alias("id")
).filter("id.n1 IS NOT NULL")
e = df.drop("group").filter("dst.n1 IS NOT NULL")

The next step is to find all connected components using graphframes:

from graphframes import GraphFrame

# Note: recent graphframes releases may require a checkpoint directory,
# e.g. sc.setCheckpointDir("/tmp/checkpoints"), before connectedComponents.
g = GraphFrame(v, e)
res = g.connectedComponents()

    +-----+-----+------------+
    |group|   id|   component|
    +-----+-----+------------+
    |    1|[0,0]|309237645312|
    |    1|[1,1]|309237645312|
    |    1|[1,1]|309237645312|
    |    1|[2,2]|309237645312|
    |    1|[1,5]| 85899345920|
    |    1|[2,6]| 85899345920|
    |    1|[2,2]|309237645312|
    |    1|[3,3]|309237645312|
    |    1|[2,6]| 85899345920|
    |    1|[3,7]| 85899345920|
    |    1|[3,3]|309237645312|
    |    1|[4,4]|309237645312|
    |    1|[3,7]| 85899345920|
    |    1|[4,4]|309237645312|
    |    1|[5,5]|309237645312|
    |    1|[5,1]|292057776128|
    |    1|[5,5]|309237645312|
    +-----+-----+------------+
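
The duplicate rows simply mirror duplicate vertex ids in v (a given (n1, n2) pair can appear both as a src and as a dst). They don't affect the components, but if you prefer a tidy vertex table you could deduplicate before building the graph:

# Optional: keep one row per (group, id) vertex; the components are unchanged.
v = v.distinct()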

Now, since the pointer relation implies that n1 and n2 increase monotonically along each chain, and struct comparison is lexicographic, the minimum id in a component is the chain's head and the maximum is its tail. We can therefore simply aggregate by component and compute the min and max:

res.groupBy("group", "component").agg(
    psf.min("id").alias("min_id"), 
    psf.max("id").alias("max_id")
)

    +-----+------------+------+------+
    |group|   component|min_id|max_id|
    +-----+------------+------+------+
    |    1|309237645312| [0,0]| [5,5]|
    |    1| 85899345920| [1,5]| [3,7]|
    |    1|292057776128| [5,1]| [5,1]|
    +-----+------------+------+------+
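
To match the exact output format from the question, unpack the struct fields into plain columns (column names follow the question's expected result):

res.groupBy("group", "component").agg(
    psf.min("id").alias("min_id"),
    psf.max("id").alias("max_id")
).select(
    "group",
    psf.col("min_id.n1").alias("n1_min"),
    psf.col("min_id.n2").alias("n2_min"),
    psf.col("max_id.n1").alias("n1_max"),
    psf.col("max_id.n2").alias("n2_max"),
).show()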

Upvotes: 2
