Reputation: 121
I'm working with a Spark DataFrame like the one below:
user | 1 | 2 | 3 | 4 | ... | 53
-------------------------------
1 | 1 | 0 | 0 | 1 | ... | 1
2 | 0 | 1 | 1 | 1 | ... | 0
3 | 1 | 1 | 0 | 0 | ... | 1
.
.
.
n | 1 | 0 | 1 | 1 | ... | 0
It has a user ID column, followed by one column per week of the year (1 to 53) holding a 0/1 flag that indicates whether the user was active in that week.
My goal is to reduce this to a table like so:
user | active_start | active_end | duration
-------------------------------------------
1 | 1 | 1 | 1
1 | 4 | 4 | 1
1 | 53 | 53 | 1
2 | 2 | 4 | 3
3 | 1 | 2 | 2
3 | 53 | 53 | 1
.
.
.
n | 1 | 1 | 1
n | 3 | 4 | 2
which contains periods of continuous activity.
I'm at somewhat of a loss as to how I should manipulate the table/aggregate the values so as to create a new row when a gap is detected.
I have tried using code for Island/Gap detection to generate these groups, but have been unable to implement a version which does not detect and generate rows for smaller sub-islands within larger ones.
Any help would be appreciated, Thanks!
Upvotes: 0
Views: 179
Reputation: 13985
Just flatMap your df with a function to calculate your metrics for every row.
Then provide the column names to your new DF.
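import scala.collection.mutable.ArrayBuffer
import spark.implicits._ // assumes a SparkSession named `spark`; needed for flatMap/toDF

val newDf = yourDf
  .flatMap { row =>
    val userId = row.getInt(0)
    val arrayBuffer = ArrayBuffer[(Int, Int, Int, Int)]()
    var start = -1 // -1 means "not currently inside an active run"
    for (i <- 1 to 53) {
      val active = row.getInt(i)
      if (active == 1 && start == -1) {
        start = i
      } else if (active == 0 && start != -1) {
        // The run ended at the previous week, so duration is end - start + 1
        val end = i - 1
        val duration = end - start + 1
        arrayBuffer.append((userId, start, end, duration))
        start = -1
      }
    }
    // Emit a run that is still open at week 53
    if (start != -1) arrayBuffer.append((userId, start, 53, 53 - start + 1))
    arrayBuffer
  }
  .toDF("user", "active_start", "active_end", "duration")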
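The scan itself does not depend on Spark, so it can be pulled out into a plain function and checked on small inputs. The following is a minimal sketch under that assumption; `activeRuns` is a hypothetical name, not part of the answer. It takes the week flags without the user ID and numbers weeks from 1. Two edge cases are worth checking here: the duration of a run that ends before a 0, and a run that is still open at the last week.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical helper: returns (start, end, duration) for each run of 1s.
// Weeks are numbered from 1, matching the column names in the question.
def activeRuns(weeks: Seq[Int]): List[(Int, Int, Int)] = {
  val runs = ListBuffer[(Int, Int, Int)]()
  var start = -1 // -1 means "not currently inside a run"
  for ((active, idx) <- weeks.zipWithIndex) {
    val week = idx + 1
    if (active == 1 && start == -1) {
      start = week
    } else if (active != 1 && start != -1) {
      // The run ended at the previous week: end = week - 1, duration = week - start
      runs += ((start, week - 1, week - start))
      start = -1
    }
  }
  // A run still open after the last week has to be emitted as well
  if (start != -1) runs += ((start, weeks.length, weeks.length - start + 1))
  runs.toList
}
```

For example, `activeRuns(Seq(1, 0, 0, 1, 1))` returns `List((1, 1, 1), (4, 5, 2))`. Inside the flatMap, each triple would be prefixed with the user ID.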
Upvotes: 0
Reputation: 17872
Here's another suggestion, also using flatMap, but with foldLeft inside to compute the intervals:
case class Interval(user: Int, active_start: Int, active_end: Int, duration: Int)
def computeIntervals(userId: Int, weeks: Seq[Int]): TraversableOnce[Interval] = {
// First, we get the indexes where the value is 1
val indexes: Seq[Int] = weeks.zipWithIndex.collect {
case (value, index) if value == 1 => index
}
// Then, we find the "breaks" in the sequence (i.e. when the difference between indexes is > 1)
val breaks: Seq[Int] = indexes.foldLeft((List[Int](), -1)) { (pair, currentValue) =>
val (breaksBuffer: List[Int], lastValue: Int) = pair
if ((currentValue - lastValue) > 1 && lastValue >= 0) (breaksBuffer :+ lastValue :+ currentValue, currentValue)
else (breaksBuffer, currentValue)
}._1
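// Then, we add the first and last indexes and re-organize in pairs
// (a user with no active weeks has no indexes, so there is nothing to pair up)
val breakPairs =
  if (indexes.isEmpty) Iterator.empty
  else (indexes.head +: breaks :+ indexes.last).map(_ + 1).grouped(2)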
// Finally, we convert each pair to an interval and return
breakPairs.map {
case List(lower, upper) => Interval(userId, lower, upper, upper-lower+1)
}
}
Running:
import org.apache.spark.sql.Row
import spark.implicits._ // must come before .toDF; also provides the Interval encoder
val df = Seq(
  (1, 1, 0, 0, 1, 1),
  (2, 0, 1, 1, 1, 0),
  (3, 0, 0, 1, 0, 1),
  (4, 1, 1, 0, 0, 1)
).toDF
df.flatMap { row: Row =>
val (userId, weeksAsSeq) = ((row.toSeq.head.asInstanceOf[Int], row.toSeq.drop(1).map(_.asInstanceOf[Int])))
computeIntervals(userId, weeksAsSeq)
}.show
+----+------------+----------+--------+
|user|active_start|active_end|duration|
+----+------------+----------+--------+
| 1| 1| 1| 1|
| 1| 4| 5| 2|
| 2| 2| 4| 3|
| 3| 3| 3| 1|
| 3| 5| 5| 1|
| 4| 1| 2| 2|
| 4| 5| 5| 1|
+----+------------+----------+--------+
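As a variation on the same idea, the fold can build the intervals directly instead of collecting break points and pairing them up afterwards. This is only a sketch, not the code that produced the output above; `mergeWeeks` is a hypothetical name, and it returns plain triples rather than `Interval` so that it runs without Spark. A row with no active weeks simply yields an empty list.

```scala
// Hypothetical helper: merges consecutive active weeks into
// (start, end, duration) triples with a single foldLeft.
def mergeWeeks(weeks: Seq[Int]): List[(Int, Int, Int)] = {
  // 1-based week numbers in which the flag is 1
  val activeWeeks = weeks.zipWithIndex.collect { case (1, idx) => idx + 1 }
  activeWeeks
    .foldLeft(List.empty[(Int, Int)]) {
      // This week directly follows the most recent interval: extend it
      case ((start, end) :: rest, week) if week == end + 1 => (start, week) :: rest
      // Otherwise this is the first week or a gap was crossed: open a new interval
      case (acc, week) => (week, week) :: acc
    }
    .reverse // intervals were prepended, so restore chronological order
    .map { case (start, end) => (start, end, end - start + 1) }
}
```

For the first sample row, `mergeWeeks(Seq(1, 0, 0, 1, 1))` returns `List((1, 1, 1), (4, 5, 2))`, which matches the two rows shown for user 1.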
Upvotes: 0