I have a DataFrame in Spark that looks like this:
| value| group| ts|
| A| X| 1|
| B| X| 2|
| B| X| 3|
| D| X| 4|
| E| X| 5|
| A| Y| 1|
| C| Y| 2|
End goal: I'd like to count how many A-B-E sequences there are (a sequence is just a list of subsequent rows), with the added constraint that consecutive parts of the sequence can be at most n rows apart. For this example, let n be 2.
Consider group X. In this case there is exactly one D between B and E (multiple consecutive Bs are ignored), which means B and E are 1 row apart, and thus there is a sequence A-B-E.
I have thought about using collect_list(), creating a string (like DNA) and using substring search with regex. But I was wondering if there's a more elegant distributed way, perhaps using window functions?
Note that the provided DataFrame is just an example. The real DataFrame (and thus the groups) can be arbitrarily long.
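For reference, the string idea mentioned above could look roughly like this in plain Scala (no Spark). This is only a sketch under assumptions: the object and function names are made up for illustration, values are single characters, `values` stands for the ts-ordered output of collect_list() for one group, and "at most n rows apart" is read as "at most n rows in between".

```scala
object AbeRegexSketch {
  // `values`: the per-group values, already ordered by ts (hypothetical input).
  def countAbe(values: Seq[String], n: Int): Int = {
    // Collapse runs of identical characters, e.g. "ABBDE" becomes "ABDE",
    // because multiple consecutive Bs are ignored.
    val collapsed = values.mkString.replaceAll("(.)\\1+", "$1")
    // Allow at most n arbitrary rows between A and B, and between B and E.
    val pattern = s"A.{0,$n}B.{0,$n}E".r
    // Counts non-overlapping matches only.
    pattern.findAllIn(collapsed).size
  }
}
```

With the example data, group X (A, B, B, D, E) gives one match and group Y (A, C) gives none. This still collects each whole group on a single executor, which is the scalability concern behind the question.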
Upvotes: 9
Views: 26968
Edited to answer @Tim's comment + fix patterns of the type "AABE"
Yep, using a window function helps, but I created an id column to have an ordering:
val df = List(
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy('group).orderBy('id)
Then lag will collect what is needed, but a function is required to generate the Column expression. Note the split, which eliminates double counting of "AABE". WARNING: this rejects patterns of the type "ABAEXX":
def createSeq(m:Int) = split(
(1 to 2*m)
.map(i => coalesce(lag('value,-i).over(w),lit("")))
val m=2
val tmp = df
| id|value|group| ts| seq|
| 6| A| Y| 1| C|
| 7| C| Y| 2| |
| 1| A| X| 1|BBDE|
| 2| B| X| 2| BDE|
| 3| B| X| 3| DE|
| 4| D| X| 4| E|
| 5| E| X| 5| |
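To see what the lookahead column contains without running Spark, here is a plain-Scala sketch that reproduces the seq column of the table above for one group. The names are made up for illustration, and cutting the lookahead at the next "A" is my reading of what the split is for, not the exact Spark expression.

```scala
object LookaheadSketch {
  // For each row of one (ordered) group, concatenate the next 2*m values,
  // which is what the lag(value, -i) calls with i = 1..2*m gather.
  def lookahead(values: Vector[String], m: Int): Vector[String] =
    values.indices.toVector.map { i =>
      values
        .slice(i + 1, i + 1 + 2 * m) // the next 2*m rows, fewer near the end
        .takeWhile(_ != "A")         // ASSUMPTION: cut at the next "A"
        .mkString
    }
}
```

For group X this yields BBDE, BDE, DE, E and an empty string, matching the table. For a group like A, A, B, E the first row gets an empty lookahead, so only the second A can start a match, which is how the double counting is avoided (and also why "ABAEXX" is rejected).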
Because of the poor set of collection functions available in the Column API, avoiding regex altogether is much easier using a UDF:
def patternInSeq(m: Int) = udf((str: String) => {
var notFound = str
.filter(_.indexOf("E") <= m)
val res = tmp
.filter(('value === "A") && (locate("B",'seq) > 0))
.filter(locate("B",'seq) <= m && (locate("E",'seq) > 1))
| X| 1|
If you want to generalise this to longer sequences of letters, the question itself has to be generalised. It could be trivial, but in this case a pattern of the type "ABAE" should be rejected (see comments). So the easiest way to generalise is to have a pair-wise rule, as in the following implementation (I added a group "Z" to illustrate the behaviour of this algorithm):
val df = List(
( 8,"A","Z",1),
( 9,"B","Z",2),
First we define the logic for a pair:
import org.apache.spark.sql.DataFrame
def createSeq(m:Int) = array((0 to 2*m).map(i => coalesce(lag('value,-i).over(w),lit(""))):_*)
def filterPairUdf(m: Int, t: (String,String)) = udf((ar: Array[String]) => {
val (a,b) = t
val foundAt = ar
.dropWhile(_ != a)
.takeWhile(_ != a)
foundAt != -1 && foundAt <= m
Then we define a function that applies this logic iteratively to the DataFrame:
def filterSeq(seq: List[String], m: Int)(df: DataFrame): DataFrame = {
var a = seq(0)
seq.tail.foldLeft(df){(df: DataFrame, b: String) => {
val res = df.filter(filterPairUdf(m,(a,b))('seq))
a = b
A simplification and optimisation is obtained by first filtering on sequences beginning with the first character:
val m = 2
val tmp = df
.filter('value === "A") // reduce problem
| id|value|group| ts| seq|
| 6| A| Y| 1| [A, C, , , ]|
| 8| A| Z| 1|[A, B, D, B, E]|
| 1| A| X| 1|[A, B, B, D, E]|
val res = tmp.transform(filterSeq(List("A","B","E"),m))
| id|value|group| ts| seq|
| 1| A| X| 1|[A, B, B, D, E]|
transform is simply syntactic sugar for applying a DataFrame => DataFrame function.
| X| 1|
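The pair-wise rule can also be tried out on plain Scala collections, without Spark. The sketch below is not the UDF itself: the names are made up, and the step that skips repeated occurrences of the first letter is an assumption, chosen because it reproduces the results shown above (X kept, Y and Z rejected).

```scala
object PairwiseSketch {
  // True if `b` shows up at most `m` positions after the leading run of `a`,
  // and before `a` appears again.
  def pairOk(window: Seq[String], m: Int, a: String, b: String): Boolean = {
    val foundAt = window
      .dropWhile(_ != a) // jump to the first `a`
      .dropWhile(_ == a) // ASSUMPTION: ignore consecutive repeats of `a`
      .takeWhile(_ != a) // a later `a` resets the search
      .indexOf(b)
    foundAt != -1 && foundAt <= m
  }

  // Apply the rule to every consecutive pair of the pattern: (A,B), (B,E), ...
  def matches(window: Seq[String], pattern: List[String], m: Int): Boolean =
    pattern.zip(pattern.tail).forall { case (a, b) => pairOk(window, m, a, b) }
}
```

For example, `PairwiseSketch.matches(Seq("A","B","B","D","E"), List("A","B","E"), 2)` holds for group X, while the group Z window A, B, D, B, E fails on the (B, E) pair because B reappears before E.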
As I said, there are different ways to generalise the "resetting rules" when scanning a sequence, but this example hopefully helps in the implementation of more complex ones.
Upvotes: 8