John Tan

Reputation: 147

Spark scala - how to do count() by conditioning on two rows

I am new to Spark and Scala, and I apologize if this is a silly question. I am stuck on a problem, which I have simplified below:

There is a data frame with three columns: "machineID" identifies a machine, "startTime" is the start timestamp of a task, and "endTime" is the end timestamp of a task.

My goal is to count how many idle intervals each machine has.
For example, in the table below, the 1st and 2nd rows show that machine #1 started at time 0, ended at time 3, and started again at time 4, so the time interval [3, 4] is idle. In the 3rd and 4th rows, machine #1 started at time 10, ended at time 20, and started again immediately, so there is no idle time between them.

machineID, startTime, endTime  
1, 0, 3  
1, 4, 8  
1, 10, 20  
1, 20, 31  
...  
1, 412, 578  
...  
2, 231, 311  
2, 781, 790  
...  

The data frame has already been grouped with groupBy("machineID").
I am using Spark 2.0.1 and Scala 2.11.8.

Upvotes: 4

Views: 682

Answers (1)

maasg

Reputation: 37435

To access previous/next rows in a DataFrame we can use Window functions. In this case, we're going to use lag to access the previous task's end time, with the window partitioned by machine id and ordered by start time.

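import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lag, sum}
// needed for the $"col" syntax and toDF; `spark` is the SparkSession (predefined in spark-shell)
import spark.implicits._

// Dataframe Schema
case class MachineData(id:String, start:Int, end:Int)
// Sample Data
val machineDF = Seq(
  MachineData("1", 0, 3), MachineData("1", 4, 8), MachineData("1", 10, 20), MachineData("1", 20, 31),
  MachineData("1", 412, 578), MachineData("2", 231, 311), MachineData("2", 781, 790)
).toDF
machineDF.show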
+---+-----+---+
| id|start|end|
+---+-----+---+
|  1|    0|  3|
|  1|    4|  8|
|  1|   10| 20|
|  1|   20| 31|
|  1|  412|578|
|  2|  231|311|
|  2|  781|790|
+---+-----+---+


// define the window as a partition over machineId, ordered by start (time)
val byMachine = Window.partitionBy($"id").orderBy($"start")
// we define a new column, "previous end" using the Lag Window function over the previously defined window
val prevEnd = lag($"end", 1).over(byMachine)

// new DF with the prevEnd column
val withPrevEnd = machineDF.withColumn("prevEnd", prevEnd)
withPrevEnd.show

+---+-----+---+-------+
| id|start|end|prevEnd|
+---+-----+---+-------+
|  1|    0|  3|   null|
|  1|    4|  8|      3|
|  1|   10| 20|      8|
|  1|   20| 31|     20|
|  1|  412|578|     31|
|  2|  231|311|   null|
|  2|  781|790|    311|
+---+-----+---+-------+

// as an example, we compute the idle time before each task as start - prevEnd (null for each machine's first task)
val idleIntervals = withPrevEnd.withColumn("diff", $"start"-$"prevEnd")
idleIntervals.show

+---+-----+---+-------+----+
| id|start|end|prevEnd|diff|
+---+-----+---+-------+----+
|  1|    0|  3|   null|null|
|  1|    4|  8|      3|   1|
|  1|   10| 20|      8|   2|
|  1|   20| 31|     20|   0|
|  1|  412|578|     31| 381|
|  2|  231|311|   null|null|
|  2|  781|790|    311| 470|
+---+-----+---+-------+----+

// to calculate the total idle time per machine, we sum the differences (sum skips the nulls). Adapt this as your business logic requires.
val totalIdleIntervals = idleIntervals.select($"id",$"diff").groupBy($"id").agg(sum("diff"))
totalIdleIntervals.show

+---+---------+
| id|sum(diff)|
+---+---------+
|  1|      384|
|  2|      470|
+---+---------+
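
The sum above gives the total idle time per machine. If what you need is the number of idle intervals, as the question asks, one possible adaptation is to count the rows whose start is strictly later than the previous end. The following is a minimal sketch, not the only way to do it. It assumes the same sample rows and column names as above, and the names idleCounts and idleIntervals are chosen here only for illustration. It also assumes a gap of length 0 does not count as idle.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, lag, when}

// In spark-shell a session named `spark` already exists; getOrCreate reuses it.
val spark = SparkSession.builder.master("local[*]").appName("idle-count").getOrCreate()
import spark.implicits._

// Same sample rows as above: (id, start, end)
val machineDF = Seq(
  ("1", 0, 3), ("1", 4, 8), ("1", 10, 20), ("1", 20, 31), ("1", 412, 578),
  ("2", 231, 311), ("2", 781, 790)
).toDF("id", "start", "end")

val byMachine = Window.partitionBy($"id").orderBy($"start")

// A gap is idle when a task starts strictly after the previous one ended.
// For the first row of each machine prevEnd is null, so the comparison is null,
// `when` yields null, and count() skips nulls.
val idleCounts = machineDF
  .withColumn("prevEnd", lag($"end", 1).over(byMachine))
  .groupBy($"id")
  .agg(count(when($"start" > $"prevEnd", true)).as("idleIntervals"))

// For the sample data: id 1 -> 3 idle intervals, id 2 -> 1
idleCounts.orderBy($"id").show
```

Using count(when(...)) keeps everything in a single aggregation. Filtering on $"diff" > 0 before the groupBy would give the same counts for machines that have at least one idle interval, but it would drop machines that have none from the result.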

Upvotes: 4
