Reputation: 151
The best way to describe the problem is to give you an example of an input and what I want to get out as output.
Input
+---+----------+-----+
| id| timestamp|count|
+---+----------+-----+
|  1|2017-06-22|    1|
|  1|2017-06-23|    0|
|  1|2017-06-24|    1|
|  2|2017-06-22|    0|
|  2|2017-06-23|    1|
+---+----------+-----+
The logic would be something like this:
if (the total number of 1s in count is equal to or higher than Y for the last X days)
    code = True
else
    code = False
Let's say X = 5 and Y = 2. Then the output should look like this:
Output
+---+-----+
| id| code|
+---+-----+
|  1| True|
|  2|False|
+---+-----+
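To pin down the rule before bringing Spark in, here is a plain-Scala sketch over an ordinary collection. It assumes that "the last X days" means the X calendar days ending at each id's most recent timestamp; `CodeRule`, `Rec` and `codes` are hypothetical names used only for illustration.

```scala
import java.time.LocalDate

// Plain-Scala sketch of the rule (no Spark).
// Assumption: "the last X days" = the X calendar days ending at each id's
// most recent timestamp, e.g. with X = 5 the window for 2017-06-24 is 06-20..06-24.
object CodeRule {
  final case class Rec(id: Int, timestamp: String, count: Int)

  def codes(rows: Seq[Rec], x: Int, y: Int): Map[Int, Boolean] =
    rows.groupBy(_.id).map { case (id, recs) =>
      // Convert each ISO date (yyyy-MM-dd) to an epoch day so differences are plain subtraction.
      val days = recs.map(r => (LocalDate.parse(r.timestamp).toEpochDay, r.count))
      val latest = days.map(_._1).max
      // Sum the counts of the rows whose day falls inside the window.
      val total = days.collect { case (day, c) if day > latest - x => c }.sum
      id -> (total >= y)
    }
}
```

With the input table above and X = 5, Y = 2, `CodeRule.codes` gives `Map(1 -> true, 2 -> false)`, matching the expected output.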
The input is a Spark SQL DataFrame (org.apache.spark.sql.DataFrame).
It doesn't sound like a very complex problem, but I am stuck on the first step: I have only managed to load the data into a DataFrame! Any ideas?
Upvotes: 0
Views: 40
Reputation: 41957
Looking at your requirement, a UDAF (user-defined aggregate function) suits best. You can check out the Databricks and ragrawal write-ups for a better understanding.
I am giving you some guidance based on what I understood, and I hope it is helpful.
First of all, you need to define the UDAF. You should be able to do this once you have read the links above.
import java.time.LocalDate
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Returns true when the counts within the last `daysToCheck` days (counted back from
// the most recent timestamp of each group) add up to at least `countsToCheck`.
// UserDefinedAggregateFunction is deprecated since Spark 3.0 but still works.
class ManosAggregateFunction(daysToCheck: Int, countsToCheck: Int) extends UserDefinedAggregateFunction {
  def inputSchema: StructType = new StructType().add("timestamp", StringType).add("count", IntegerType)
  // The buffer holds one value: a map from epoch day to the summed count for that day.
  // Spark gives no ordering guarantee for update/merge, so the buffer must not rely on one.
  def bufferSchema: StructType = new StructType().add("perDay", MapType(LongType, LongType))
  // The result is a single boolean: the `code` column.
  def dataType: DataType = BooleanType
  // The same input always gives the same result.
  def deterministic: Boolean = true

  // Keep only the days inside the window ending at the latest day seen so far.
  // A dropped day can never re-enter the window, because the latest day only grows.
  private def prune(perDay: Map[Long, Long]): Map[Long, Long] =
    if (perDay.isEmpty) perDay
    else {
      val latest = perDay.keys.max
      perDay.filter { case (day, _) => day > latest - daysToCheck }
    }

  def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Map.empty[Long, Long])

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0) && !input.isNullAt(1)) {
      val day = LocalDate.parse(input.getString(0)).toEpochDay // ISO yyyy-MM-dd
      val perDay = buffer.getMap[Long, Long](0).toMap
      buffer.update(0, prune(perDay.updated(day, perDay.getOrElse(day, 0L) + input.getInt(1))))
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val left = buffer1.getMap[Long, Long](0).toMap
    val right = buffer2.getMap[Long, Long](0).toMap
    // Add the two partial buffers day by day, then drop days outside the window.
    val combined = right.foldLeft(left) { case (acc, (day, c)) =>
      acc.updated(day, acc.getOrElse(day, 0L) + c)
    }
    buffer1.update(0, prune(combined))
  }

  def evaluate(buffer: Row): Any =
    buffer.getMap[Long, Long](0).values.sum >= countsToCheck
}
In the above UDAF, daysToCheck and countsToCheck are the X and Y in your question.
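Spark does not promise the order in which rows reach `update` or in which partial buffers reach `merge`, so a UDAF's result should not depend on that order. One way to achieve this is to keep per-day totals in the buffer and drop days that have left the window. The plain-Scala model below (no Spark needed) mimics the initialize/update/merge/evaluate lifecycle so you can check that different partitionings give the same answer; `PerDayBuffer` and its methods are hypothetical names, not Spark API.

```scala
import java.time.LocalDate

// Plain-Scala model of an order-independent aggregation buffer (hypothetical, not Spark API).
// The buffer maps epoch day -> summed count and keeps only the `daysToCheck`
// days ending at the latest day seen so far.
object PerDayBuffer {
  type Buf = Map[Long, Long]

  def prune(buf: Buf, daysToCheck: Int): Buf =
    if (buf.isEmpty) buf
    else {
      val latest = buf.keys.max
      buf.filter { case (day, _) => day > latest - daysToCheck }
    }

  // Counterpart of `update`: fold one (timestamp, count) row into a buffer.
  def update(buf: Buf, timestamp: String, count: Int, daysToCheck: Int): Buf = {
    val day = LocalDate.parse(timestamp).toEpochDay
    prune(buf.updated(day, buf.getOrElse(day, 0L) + count), daysToCheck)
  }

  // Counterpart of `merge`: add two partial buffers day by day.
  def merge(a: Buf, b: Buf, daysToCheck: Int): Buf =
    prune(b.foldLeft(a) { case (acc, (day, c)) => acc.updated(day, acc.getOrElse(day, 0L) + c) }, daysToCheck)

  // Counterpart of `evaluate`.
  def evaluate(buf: Buf, countsToCheck: Int): Boolean = buf.values.sum >= countsToCheck

  // Simulates Spark: each "partition" is folded with update, then the partials are merged.
  def run(partitions: Seq[Seq[(String, Int)]], daysToCheck: Int, countsToCheck: Int): Boolean = {
    val partials = partitions.map(_.foldLeft(Map.empty[Long, Long]) { case (buf, (ts, c)) =>
      update(buf, ts, c, daysToCheck)
    })
    evaluate(partials.foldLeft(Map.empty[Long, Long])((acc, p) => merge(acc, p, daysToCheck)), countsToCheck)
  }
}
```

For id 1's rows, `run` returns `true` with X = 5 and Y = 2 whether the rows sit in one partition, one row per partition in reverse order, or any other split.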
You can call the UDAF as below:
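import org.apache.spark.sql.functions.col
import spark.implicits._ // for $"..." and toDF (already imported in spark-shell)

val manosAgg = new ManosAggregateFunction(5, 2)
df.orderBy($"timestamp".desc).groupBy("id").agg(manosAgg(col("timestamp"), col("count")).as("code")).show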
The final output is:
+---+-----+
| id| code|
+---+-----+
| 1| true|
| 2|false|
+---+-----+
Here is the input used:
val df = Seq(
(1, "2017-06-22", 1),
(1, "2017-06-23", 0),
(1, "2017-06-24", 1),
(2, "2017-06-28", 0),
(2, "2017-06-29", 1)
).toDF("id","timestamp","count")
+---+----------+-----+
|id |timestamp |count|
+---+----------+-----+
|1 |2017-06-22|1 |
|1 |2017-06-23|0 |
|1 |2017-06-24|1 |
|2 |2017-06-28|0 |
|2 |2017-06-29|1 |
+---+----------+-----+
I hope this gives you the idea for solving your problem. :)
Upvotes: 1