Manos

Reputation: 151

Categorise data after counting the number of occurrences in Spark

The best way to describe the problem is to give an example of the input and the output I want to get.

Input

+---+----------+-----+
|id |timestamp |count|
+---+----------+-----+
|1  |2017-06-22|1    |
|1  |2017-06-23|0    |
|1  |2017-06-24|1    |
|2  |2017-06-22|0    |
|2  |2017-06-23|1    |
+---+----------+-----+

The logic would be something like: if (the total number of 1s in count over the last X days is equal to or higher than Y)

code = True 

else

code = False 

Let's say X = 5 and Y = 2; then the output should look like:

Output

+----+-------+
| id | code  |
+----+-------+
|  1 | True  |
|  2 | False |
+----+-------+

The input is a Spark SQL DataFrame (org.apache.spark.sql.DataFrame).

It doesn't sound like a very complex problem, but I am stuck on the first step: I have only managed to load the data into a DataFrame!

Any ideas?

Upvotes: 0

Views: 40

Answers (1)

Ramesh Maharjan

Reputation: 41957

Looking at your requirement, a UDAF (user-defined aggregate function) suits best. You can check out the databricks and ragrawal write-ups on UDAFs for a better understanding.

I am providing guidance based on what I understood, and I hope it is helpful.

First of all, you need to define the UDAF; the links above explain how.

import java.time.LocalDate
import java.time.format.DateTimeFormatter

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

private class ManosAggregateFunction(daysToCheck: Int, countsToCheck: Int) extends UserDefinedAggregateFunction {

  // input rows carry the timestamp (as a string) and the count
  def inputSchema: StructType = new StructType().add("timestamp", StringType).add("count", IntegerType)
  // the buffer tracks the last timestamp seen, the running total of counts,
  // and the number of days covered so far
  def bufferSchema: StructType = new StructType().add("timestamp", StringType).add("count", IntegerType).add("days", IntegerType)
  // the final result is the Boolean code
  def dataType: DataType = BooleanType
  // the same input always produces the same result
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, "")
    buffer.update(1, 0)
    buffer.update(2, 0)
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // store the incoming row's values; the date arithmetic happens in merge
    buffer.update(0, input.getString(0))
    buffer.update(1, input.getInt(1))
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")
    val previousDate = buffer1.getString(0)
    val nowDate = buffer2.getString(0)
    if (previousDate != "") {
      val oldDate = LocalDate.parse(previousDate, formatter)
      val newDate = LocalDate.parse(nowDate, formatter)
      // extend the day span covered so far
      buffer1.update(2, buffer1.getInt(2) + (oldDate.toEpochDay() - newDate.toEpochDay()).toInt)
    }
    buffer1.update(0, buffer2.getString(0))
    // only accumulate counts while still within the last daysToCheck days
    if (buffer1.getInt(2) < daysToCheck) {
      buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1))
    }
  }

  def evaluate(buffer: Row): Any = {
    countsToCheck <= buffer.getInt(1)
  }
}

In the above UDAF, daysToCheck and countsToCheck are the X and Y in your question.

You can call the defined UDAF as below

val manosAgg = new ManosAggregateFunction(5, 2)
df.orderBy($"timestamp".desc)
  .groupBy("id")
  .agg(manosAgg(col("timestamp"), col("count")).as("code"))
  .show()
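Note the orderBy($"timestamp".desc) before the groupBy: merge adds oldDate - newDate to the running day span, which is only non-negative when each id's rows arrive newest first.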

Final output is

+---+-----+
| id| code|
+---+-----+
|  1| true|
|  2|false|
+---+-----+

Given the input:

val df = Seq(
  (1, "2017-06-22", 1),
  (1, "2017-06-23", 0),
  (1, "2017-06-24", 1),
  (2, "2017-06-28", 0),
  (2, "2017-06-29", 1)
).toDF("id","timestamp","count")
+---+----------+-----+
|id |timestamp |count|
+---+----------+-----+
|1  |2017-06-22|1    |
|1  |2017-06-23|0    |
|1  |2017-06-24|1    |
|2  |2017-06-28|0    |
|2  |2017-06-29|1    |
+---+----------+-----+
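For comparison, here is a minimal non-UDAF sketch of the same logic using only built-in Spark SQL functions and a window. It assumes, as above, that the "last X days" window is anchored at each id's most recent timestamp, with daysToCheck and countsToCheck again standing in for X and Y:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val daysToCheck = 5
val countsToCheck = 2

val byId = Window.partitionBy("id")
df.withColumn("latest", max(to_date(col("timestamp"))).over(byId))          // newest date per id
  .where(datediff(col("latest"), to_date(col("timestamp"))) < daysToCheck)  // keep only the last X days
  .groupBy("id")
  .agg((sum("count") >= countsToCheck).as("code"))
  .show()

On the sample input above this produces the same output as the UDAF.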

I hope this gives you the idea for solving your problem. :)

Upvotes: 1
