Given a positive Long i and a DataFrame
+-----+--+--+
|group|n1|n2|
+-----+--+--+
|    1| 0| 0|
|    1| 1| 1|
|    1| 1| 5|
|    1| 2| 2|
|    1| 2| 6|
|    1| 3| 3|
|    1| 3| 7|
|    1| 4| 4|
|    1| 5| 1|
|    1| 5| 5|
+-----+--+--+
how would you sessionize rows in the same group such that, for each pair of consecutive rows r1, r2 in a session, r2.n1 > r1.n1, r2.n2 > r1.n2, and max(r2.n1 - r1.n1, r2.n2 - r1.n2) < i? Note that n1 and n2 values may not be unique, meaning the rows that make up a session may not be consecutive in the DataFrame.
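To make the condition concrete, here is a minimal plain-Python sketch (no Spark involved). The function name can_follow and the representation of a row as an (n1, n2) tuple are illustrative assumptions, not part of the question.

```python
# Plain-Python sketch of the pairing condition from the question.
# A row is modeled as a hypothetical (n1, n2) tuple; i is the positive threshold.

def can_follow(r1, r2, i):
    """Return True if r2 may directly follow r1 in a session."""
    d1 = r2[0] - r1[0]
    d2 = r2[1] - r1[1]
    # Both values must strictly increase, and the larger step must stay below i.
    return d1 > 0 and d2 > 0 and max(d1, d2) < i


# A few checks against the sample data with i = 3:
print(can_follow((0, 0), (1, 1), 3))  # True: steps (1, 1)
print(can_follow((1, 1), (1, 5), 3))  # False: n1 does not increase
print(can_follow((2, 2), (5, 5), 3))  # False: a step of 3 is not < 3
```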
As an example, the result for the given DataFrame and i = 3 would be:
+-----+--+--+-------+
|group|n1|n2|session|
+-----+--+--+-------+
|    1| 0| 0|      1|
|    1| 1| 1|      1|
|    1| 1| 5|      2|
|    1| 2| 2|      1|
|    1| 2| 6|      2|
|    1| 3| 3|      1|
|    1| 3| 7|      2|
|    1| 4| 4|      1|
|    1| 5| 1|      3|
|    1| 5| 5|      1|
+-----+--+--+-------+
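A small plain-Python check, assuming each session's rows are ordered by (n1, n2), suggests the expected output is consistent with the condition: every consecutive pair inside each session satisfies it. The helper names can_follow and sessions_are_valid are hypothetical.

```python
from collections import defaultdict

# Expected output from the question as (n1, n2, session) for group 1, i = 3.
expected = [(0, 0, 1), (1, 1, 1), (1, 5, 2), (2, 2, 1), (2, 6, 2),
            (3, 3, 1), (3, 7, 2), (4, 4, 1), (5, 1, 3), (5, 5, 1)]
i = 3

def can_follow(r1, r2, i):
    # Both values strictly increase and the larger step is below i.
    d1, d2 = r2[0] - r1[0], r2[1] - r1[1]
    return d1 > 0 and d2 > 0 and max(d1, d2) < i

def sessions_are_valid(rows, i):
    # Group rows by session label, then check each consecutive pair.
    by_session = defaultdict(list)
    for n1, n2, s in rows:
        by_session[s].append((n1, n2))
    for members in by_session.values():
        members.sort()  # order each session by (n1, n2)
        if not all(can_follow(a, b, i) for a, b in zip(members, members[1:])):
            return False
    return True

print(sessions_are_valid(expected, i))  # True
```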
Any help or hints will be greatly appreciated. Thanks!
This looks like you're trying to label the connected components of a graph, giving every row in the same component the same number. A good solution would be to use GraphFrames: https://graphframes.github.io/quick-start.html
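For intuition about what "connected components" means here, the following is a minimal union-find sketch in plain Python. It is not how GraphFrames computes components (that runs distributed on Spark), and the function name connected_components is hypothetical; like connected components in general, it ignores edge direction.

```python
def connected_components(vertices, edges):
    """Map each vertex to a representative of its connected component (union-find)."""
    parent = {v: v for v in vertices}

    def find(v):
        # Walk up to the root, halving the path as we go.
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    for a, b in edges:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra  # merge the two trees; direction of (a, b) does not matter
    return {v: find(v) for v in vertices}


# Tiny example: a-b-c are linked, d stands alone.
comp = connected_components(["a", "b", "c", "d"], [("a", "b"), ("c", "b")])
print(comp["a"] == comp["c"], comp["d"] == comp["a"])  # True False
```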
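Starting from your DataFrame:
# assumes the pyspark shell, where sc (SparkContext) and a SparkSession already exist
df = sc.parallelize([[1, 0, 0], [1, 1, 1], [1, 1, 5], [1, 2, 2], [1, 2, 6],
                     [1, 3, 3], [1, 3, 7], [1, 4, 4], [1, 5, 1], [1, 5, 5]]).toDF(["group", "n1", "n2"])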
We'll create a vertex DataFrame whose id column is the (n1, n2) pair, which is unique for every row in this example:
import pyspark.sql.functions as psf
v = df.select(psf.struct("n1", "n2").alias("id"), "group")
+-----+-----+
|   id|group|
+-----+-----+
|[0,0]|    1|
|[1,1]|    1|
|[1,5]|    1|
|[2,2]|    1|
|[2,6]|    1|
|[3,3]|    1|
|[3,7]|    1|
|[4,4]|    1|
|[5,1]|    1|
|[5,5]|    1|
+-----+-----+
And an edge dataframe defined from the boolean condition you stated:
i = 3
# Self-join: an edge r1 -> r2 exists when r2 may follow r1 in a session.
e = df.alias("r1").join(
    df.alias("r2"),
    (psf.col("r1.group") == psf.col("r2.group"))   # same group
    & (psf.col("r1.n1") < psf.col("r2.n1"))        # n1 strictly increases
    & (psf.col("r1.n2") < psf.col("r2.n2"))        # n2 strictly increases
    & (psf.greatest(                               # larger step is below i
        psf.col("r2.n1") - psf.col("r1.n1"),
        psf.col("r2.n2") - psf.col("r1.n2")) < i)
).select(psf.struct("r1.n1", "r1.n2").alias("src"),
         psf.struct("r2.n1", "r2.n2").alias("dst"))
+-----+-----+
|  src|  dst|
+-----+-----+
|[0,0]|[1,1]|
|[0,0]|[2,2]|
|[1,1]|[2,2]|
|[1,1]|[3,3]|
|[1,5]|[2,6]|
|[1,5]|[3,7]|
|[2,2]|[3,3]|
|[2,2]|[4,4]|
|[2,6]|[3,7]|
|[3,3]|[4,4]|
|[3,3]|[5,5]|
|[4,4]|[5,5]|
+-----+-----+
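As a sanity check outside Spark, the same join predicate can be evaluated over all ordered pairs of the sample rows in plain Python. This sketch assumes the rows are held as (group, n1, n2) tuples; it yields the same 12 edges as the edge table.

```python
from itertools import product

# Sample rows (group, n1, n2) from the question and the threshold i.
rows = [(1, 0, 0), (1, 1, 1), (1, 1, 5), (1, 2, 2), (1, 2, 6),
        (1, 3, 3), (1, 3, 7), (1, 4, 4), (1, 5, 1), (1, 5, 5)]
i = 3

# Same predicate as the Spark self-join: same group, both values strictly
# increase, and the larger increase is below i.
edges = [
    ((r1[1], r1[2]), (r2[1], r2[2]))
    for r1, r2 in product(rows, repeat=2)
    if r1[0] == r2[0]
    and r1[1] < r2[1]
    and r1[2] < r2[2]
    and max(r2[1] - r1[1], r2[2] - r1[2]) < i
]

print(len(edges))  # 12
```

Note that (5, 1) takes part in no edge, which is why it ends up alone in its own session.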
And now to find all connected components:
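from graphframes import GraphFrame

# connectedComponents() needs a Spark checkpoint directory (GraphFrames 0.3+);
# the path below is only a placeholder.
sc.setCheckpointDir("/tmp/graphframes-checkpoints")
g = GraphFrame(v, e)
res = g.connectedComponents()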
+-----+-----+------------+
|   id|group|   component|
+-----+-----+------------+
|[0,0]|    1|309237645312|
|[1,1]|    1|309237645312|
|[1,5]|    1| 85899345920|
|[2,2]|    1|309237645312|
|[2,6]|    1| 85899345920|
|[3,3]|    1|309237645312|
|[3,7]|    1| 85899345920|
|[4,4]|    1|309237645312|
|[5,1]|    1|292057776128|
|[5,5]|    1|309237645312|
+-----+-----+------------+
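The component ids are arbitrary Long values rather than 1, 2, 3. One hypothetical way to turn them into the session numbers from the question is to rank components by their smallest (n1, n2) member; that ordering rule is an assumption chosen to match the example, and with several groups it would have to be applied per group. In Spark this could likely be done with window functions; the sketch below shows only the logic, in plain Python, using the ids from the output table.

```python
# Component ids as reported by connectedComponents(), keyed by (n1, n2).
component = {
    (0, 0): 309237645312, (1, 1): 309237645312, (1, 5): 85899345920,
    (2, 2): 309237645312, (2, 6): 85899345920, (3, 3): 309237645312,
    (3, 7): 85899345920, (4, 4): 309237645312, (5, 1): 292057776128,
    (5, 5): 309237645312,
}

# Find the smallest (n1, n2) member of each component.
first_member = {}
for vertex, comp in component.items():
    first_member[comp] = min(first_member.get(comp, vertex), vertex)

# Number components 1, 2, 3, ... in order of their smallest member.
ranked = sorted(first_member, key=first_member.get)
session_of_component = {comp: rank for rank, comp in enumerate(ranked, start=1)}

session = {vertex: session_of_component[comp] for vertex, comp in component.items()}
print(session[(0, 0)], session[(1, 5)], session[(5, 1)])  # 1 2 3
```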