Havnar

Reputation: 2628

Iterating over a grouped dataset in Spark 1.6

In an ordered dataset, I want to aggregate data until a condition is met, but grouped by a certain key.

To give my question some context, I have simplified my problem to the following statement:

In Spark I need to aggregate strings, grouped by key, up to the point where a user stops "shouting" (i.e., the 2nd character of a string is not uppercase).

Dataset example:

ID, text, timestamps

1, "OMG I like bananas", 123
1, "Bananas are the best", 234
1, "MAN I love banana", 1235
2, "ORLY? I'm more into grapes", 123565
2, "BUT I like apples too", 999
2, "unless you count veggies", 9999
2, "THEN don't forget tomatoes", 999999

The expected result would be:

1, "OMG I like bananas Bananas are the best"
2, "ORLY? I'm more into grapes BUT I like apples too unless you count veggies"

Via groupBy and agg I can't seem to set a condition along the lines of "stop when a string whose 2nd character is not uppercase is found".
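For reference, a minimal sketch of the kind of attempt that falls short (assuming a DataFrame df with the columns above): built-in aggregates simply consume the whole group, with no way to stop early.

import org.apache.spark.sql.functions._

// Concatenates *all* texts per ID; there is no condition that makes
// collect_list stop once a non-"shouting" string is reached.
val attempt = df.groupBy("ID")
  .agg(concat_ws(" ", collect_list(col("text"))).as("text"))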

Upvotes: 0

Views: 1276

Answers (1)

Leonardo Herrera

Reputation: 8406

Note: this only works in Spark 2.1 or above.

What you want to do is possible, but it may be very expensive.

First, let's create some test data. As general advice, when you ask something on Stack Overflow, please provide something like this so people have somewhere to start.

import spark.sqlContext.implicits._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val df = List(
    (1,  "OMG I like bananas", 1),
    (1, "Bananas are the best", 2),
    (1, "MAN I love banana", 3),
    (2, "ORLY? I'm more into grapes", 1),
    (2, "BUT I like apples too", 2),
    (2, "unless you count veggies", 3),
    (2, "THEN don't forget tomatoes", 4)
).toDF("ID", "text", "timestamps")

To get a column with the texts collected in timestamp order, we add a new column using a window function.

Using the spark shell:

scala> val df2 = df.withColumn("coll", collect_list("text").over(Window.partitionBy("id").orderBy("timestamps")))
df2: org.apache.spark.sql.DataFrame = [ID: int, text: string ... 2 more fields]

scala> val x = df2.groupBy("ID").agg(max($"coll").as("texts"))
x: org.apache.spark.sql.DataFrame = [ID: int, texts: array<string>]

scala> x.collect.foreach(println)
[1,WrappedArray(OMG I like bananas, Bananas are the best, MAN I love banana)]
[2,WrappedArray(ORLY? I'm more into grapes, BUT I like apples too, unless you count veggies, THEN don't forget tomatoes)]

Each row now carries a growing prefix of its group's texts, so taking max($"coll") per ID keeps the complete array (a longer array compares greater than any of its prefixes). To turn that array into the actual text we need a UDF. Here's mine (I'm far from an expert in Scala, so bear with me):

import scala.collection.mutable

val aggText: Seq[String] => String = (list: Seq[String]) => {
    // Keep strings while the user is "shouting" (first two characters
    // uppercase); include the first non-shouting string, then stop.
    def tex(arr: Seq[String], accum: Seq[String]): Seq[String] = arr match {
        case Seq() => accum
        case Seq(single) => accum :+ single
        case Seq(str, xs @ _*) =>
            if (str.length >= 2 && !(str.charAt(0).isUpper && str.charAt(1).isUpper))
                tex(Nil, accum :+ str) // stop: keep this string, drop the rest
            else
                tex(xs, accum :+ str) // still shouting: keep collecting
    }

    val res = tex(list, Seq())
    res.mkString(" ")
}
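A quick sanity check of the function in the shell, using ID 1's texts from the sample data:

scala> aggText(Seq("OMG I like bananas", "Bananas are the best", "MAN I love banana"))
res0: String = OMG I like bananas Bananas are the best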

val textUDF = udf(aggText(_: mutable.WrappedArray[String]))

So, we have a DataFrame with the collected texts in the proper order and a Scala function (wrapped as a UDF). Let's piece it together:

scala> val x = df2.groupBy("ID").agg(max($"coll").as("texts"))
x: org.apache.spark.sql.DataFrame = [ID: int, texts: array<string>]

scala> val y = x.select($"ID", textUDF($"texts"))
y: org.apache.spark.sql.DataFrame = [ID: int, UDF(texts): string]

scala> y.collect.foreach(println)
[1,OMG I like bananas Bananas are the best]
[2,ORLY? I'm more into grapes BUT I like apples too unless you count veggies]


I think this is the result you want.
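Since the question targets Spark 1.6, which predates collect_list as a window function, here is a rough, untested sketch of an RDD-based alternative with the same take-until logic (reusing the df from above; column positions assumed as ID, text, timestamps):

val grouped = df.rdd
  .map(row => (row.getInt(0), (row.getInt(2), row.getString(1))))
  .groupByKey()
  .mapValues { msgs =>
    // Sort each group by timestamp, keep strings while "shouting",
    // and include the first non-shouting one.
    val sorted = msgs.toSeq.sortBy(_._1).map(_._2)
    val (shouting, rest) = sorted.span(s =>
      s.length >= 2 && s.charAt(0).isUpper && s.charAt(1).isUpper)
    (shouting ++ rest.take(1)).mkString(" ")
  }

grouped.collect.foreach(println)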

Upvotes: 2
