gjin

Merging rows in a dataset

Input data

val df = Seq(
  ("1", 1, 1, "on"),
  ("1", 2, 2, "off"),
  ("1", 2, 5, "off"),
  ("1", 5, 5, "on"),
  ("1", 5, 6, "off"),
  ("2", 1, 1, "off"),
  ("2", 1, 2, "off"),
  ("2", 2, 2, "on"),
  ("2", 3, 4, "off"),
  ("2", 5, 7, "off"),
  ("2", 8, 10, "on"),
  ("2", 11, 11, "on"),
  ("2", 11, 12, "off"),
  ("3", 1, 12, "off")
).toDF("id", "start", "end", "sw")

I'm trying to merge rows using groupByKey and mapGroups.

Desired output

1 1 5 on
1 5 6 on
2 1 2 off
2 2 7 on
2 8 10 on
2 11 12 on
3 1 12 off

The logic is as follows: every off row is merged into the preceding on row. If the first (or only) rows of a group are off, they are merged into a single off row. Each merged row takes its start from the first row and its end from the last row. The data should be sorted by start and end.
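The rule above can be sketched in plain Scala, without Spark, as a fold over one group's rows. This is a minimal sketch, not the asker's approach (it uses an immutable fold rather than a ListBuffer loop), and it assumes the rows all share one id and are already sorted by start and end. Rec and MergeRule are hypothetical names standing in for the question's Row:

```scala
object MergeRule {
  // Hypothetical immutable stand-in for the question's Row case class.
  final case class Rec(id: String, start: Int, end: Int, sw: String)

  // Assumes every row belongs to the same id and rows are sorted by (start, end).
  def mergeSorted(rows: List[Rec]): List[Rec] =
    rows.foldLeft(List.empty[Rec]) {
      // An "off" row extends the most recent merged row (the head of the accumulator).
      case (last :: done, r) if r.sw == "off" => last.copy(end = r.end) :: done
      // An "on" row, or the very first row of the group, starts a new merged row.
      case (acc, r) => r :: acc
    }.reverse // the accumulator was built newest-first
}
```

A leading run of off rows has nothing to attach to, so the first of them starts an off row that the following off rows extend, which gives the single off row described above.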

Here is what I have so far

df
  .as[Row]
  .groupByKey(_.id)
  .mapGroups{case(k, iter) => Row.merge(iter)}

I group the data by id and then try to iterate over the values.

import scala.collection.mutable.ListBuffer

case class Row(id:String, start:Int, var end:Int, sw:String)

object Row {
  def merge(iter: Iterator[Row]): ListBuffer[Row] = {
    val listBuffer = new ListBuffer[Row]
    var bufferRow = Row("", 0, 0, "")
    for(row <- iter){
      if(listBuffer.isEmpty) bufferRow = row
      else if(row.sw == "off") bufferRow.end = row.end
      else if(row.sw == "on") {
        listBuffer += bufferRow
        bufferRow = row
      }
    }
    if(listBuffer.isEmpty) listBuffer += bufferRow
    listBuffer
  }
}

My output

[WrappedArray([1,5,6,off])]
[WrappedArray([2,11,12,off])]
[WrappedArray([3,1,12,off])]
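The output above can be reproduced without Spark by feeding one group's rows straight into the same merge logic. In this sketch the case class is renamed to the hypothetical Rec so it does not clash with org.apache.spark.sql.Row; otherwise the method body is the question's:

```scala
import scala.collection.mutable.ListBuffer

object OriginalMerge {
  // Hypothetical local copy of the question's Row case class.
  case class Rec(id: String, start: Int, var end: Int, sw: String)

  // The question's merge method, unchanged apart from the type name.
  def merge(iter: Iterator[Rec]): ListBuffer[Rec] = {
    val listBuffer = new ListBuffer[Rec]
    var bufferRow = Rec("", 0, 0, "")
    for (row <- iter) {
      if (listBuffer.isEmpty) bufferRow = row
      else if (row.sw == "off") bufferRow.end = row.end
      else if (row.sw == "on") {
        listBuffer += bufferRow
        bufferRow = row
      }
    }
    if (listBuffer.isEmpty) listBuffer += bufferRow
    listBuffer
  }
}
```

For id 1 this returns a single row, (1, 5, 6, off), the same as the first line of the output above.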

I already accomplished something similar using window functions and cumulative sum. Here I'm trying to learn a new approach.

Using Spark 2.2, Scala 2.11.

Answers (1)

Vincent Doba

Your proposed solution is almost right; it just needs a few adjustments.

First, as your method Row.merge returns a list of Rows instead of a single Row, you should use flatMapGroups to flatten each list into separate records in your Dataset:

df
  .as[Row]
  .groupByKey(_.id)
  .flatMapGroups { case (k, iter) => Row.merge(iter) }
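The difference mirrors map versus flatMap on ordinary Scala collections. A small plain-Scala illustration (not Spark code; the even/odd grouping is invented for the example):

```scala
object MapVsFlatMap {
  // Group some numbers by parity; each group keeps its elements in input order.
  val groups: Map[String, List[Int]] =
    List(1, 2, 3, 4, 5).groupBy(n => if (n % 2 == 0) "even" else "odd")

  // Like mapGroups: exactly one output element per group, here a whole List.
  val mapped: List[List[Int]] =
    groups.toList.sortBy(_._1).map { case (_, xs) => xs }

  // Like flatMapGroups: each group's List is flattened into individual elements.
  val flatMapped: List[Int] =
    groups.toList.sortBy(_._1).flatMap { case (_, xs) => xs }
}
```

With mapGroups, each group's ListBuffer became one record, which is why the original output showed one WrappedArray per id.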

Next, let's dive into your Row.merge method.

You create a placeholder bufferRow that you populate on the first iteration of your loop over iter with the statement if (listBuffer.isEmpty) bufferRow = row. However, the condition in this statement is true on every iteration: rows are only appended to listBuffer in the branches that are skipped while it is empty, so it never stops being empty. That's why your output contains only the last row of each group. So this statement should be removed. To initialize your bufferRow, you can just call iter.next():

... = new ListBuffer[Row]
var bufferRow = iter.next()
for(row <- iter) { ...

As the iterator comes from a groupByKey, it contains at least one element, so the first call to iter.next() doesn't throw an exception. And since calling .next() consumes the first element of the iterator, the loop that follows doesn't process this first element again.
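The iterator behaviour relied on here can be checked on a plain Scala Iterator. The helper firstAndRest is a hypothetical name used only for this illustration:

```scala
object IteratorNextDemo {
  // Mirrors the pattern above: take the first element with next(),
  // then consume whatever remains in the same iterator.
  def firstAndRest[A](iter: Iterator[A]): (A, List[A]) = {
    val first = iter.next() // throws NoSuchElementException if iter is empty
    val rest = iter.toList  // the first element is no longer in the iterator
    (first, rest)
  }
}
```

On an empty iterator the call to next() fails, which is why this pattern is only safe when every group is guaranteed to be non-empty.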

Next, the last statement of your method before the returned value is if (listBuffer.isEmpty) listBuffer += bufferRow. This statement should not be conditional.

Indeed, in your loop you build up bufferRow, and you add it to listBuffer only when the currently processed row has its "sw" field set to "on"; that row then becomes the new bufferRow. This means the last bufferRow is never saved in listBuffer, except when listBuffer is empty. So the last lines of your merge method should be:

...
    bufferRow = row
  }
}
listBuffer += bufferRow

We now have the complete merge method:

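def merge(iter: Iterator[Row]): ListBuffer[Row] = {
  val listBuffer = new ListBuffer[Row]
  // a group coming from groupByKey always has at least one row, so next() is safe
  var bufferRow = iter.next()
  for (row <- iter) {
    if (row.sw == "off") {
      // an "off" row extends the row currently being built
      bufferRow.end = row.end
    } else if (row.sw == "on") {
      // an "on" row closes the current row and starts a new one
      listBuffer += bufferRow
      bufferRow = row
    }
  }
  // the last row is never added inside the loop; += returns listBuffer itself
  listBuffer += bufferRow
}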

Running this code gives the following result, sorted by the id and start columns:

+---+-----+---+---+
|id |start|end|sw |
+---+-----+---+---+
|1  |1    |5  |on |
|1  |5    |6  |on |
|2  |1    |2  |off|
|2  |2    |7  |on |
|2  |8    |10 |on |
|2  |11   |12 |on |
|3  |1    |12 |off|
+---+-----+---+---+
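The result above can be checked without a Spark session by approximating groupByKey(_.id).flatMapGroups(...) locally with List.groupBy followed by flatMap. This harness is a sketch: Rec and AnswerMerge are hypothetical names, and the local groupBy keeps each group's rows in input order, which Spark does not promise:

```scala
import scala.collection.mutable.ListBuffer

object AnswerMerge {
  // Hypothetical local stand-in for the question's Row case class.
  case class Rec(id: String, start: Int, var end: Int, sw: String)

  // The answer's merge method, unchanged apart from the type name.
  def merge(iter: Iterator[Rec]): ListBuffer[Rec] = {
    val listBuffer = new ListBuffer[Rec]
    var bufferRow = iter.next()
    for (row <- iter) {
      if (row.sw == "off") {
        bufferRow.end = row.end
      } else if (row.sw == "on") {
        listBuffer += bufferRow
        bufferRow = row
      }
    }
    listBuffer += bufferRow
  }

  // The question's input data. A def, so every call gets fresh objects,
  // because merge mutates `end` on the rows it keeps.
  def input: List[Rec] = List(
    Rec("1", 1, 1, "on"),
    Rec("1", 2, 2, "off"),
    Rec("1", 2, 5, "off"),
    Rec("1", 5, 5, "on"),
    Rec("1", 5, 6, "off"),
    Rec("2", 1, 1, "off"),
    Rec("2", 1, 2, "off"),
    Rec("2", 2, 2, "on"),
    Rec("2", 3, 4, "off"),
    Rec("2", 5, 7, "off"),
    Rec("2", 8, 10, "on"),
    Rec("2", 11, 11, "on"),
    Rec("2", 11, 12, "off"),
    Rec("3", 1, 12, "off")
  )

  // Local approximation of groupByKey(_.id).flatMapGroups(...),
  // then sorted by (id, start) like the table above.
  def run(rows: List[Rec]): List[Rec] = {
    val merged = rows.groupBy(_.id).toList.flatMap { case (_, group) => merge(group.iterator) }
    merged.sortBy(r => (r.id, r.start))
  }
}
```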

Final note: be careful about row ordering if you run this code on partitioned datasets. Spark's groupByKey does not guarantee the order of rows inside each group's iterator, so rows may not arrive sorted by start and end. It is safer to sort the iterator with .toSeq.sortBy(...) before iterating over it.
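A minimal sketch of sorting before merging, in plain Scala so it runs without Spark: the group is sorted by (start, end) and then goes through the same merge loop. Rec, SortedMerge and mergeSorted are hypothetical names; in a Spark job the same body would sit inside .flatMapGroups { case (_, iter) => ... }:

```scala
import scala.collection.mutable.ListBuffer

object SortedMerge {
  // Hypothetical local stand-in for the question's Row case class.
  case class Rec(id: String, start: Int, var end: Int, sw: String)

  // Sorts one group by (start, end) first, so the result no longer
  // depends on the order in which rows arrive, then merges as before.
  def mergeSorted(iter: Iterator[Rec]): List[Rec] = {
    val sorted = iter.toSeq.sortBy(r => (r.start, r.end)).iterator
    val out = new ListBuffer[Rec]
    var current = sorted.next()
    for (row <- sorted) {
      if (row.sw == "off") {
        current.end = row.end
      } else if (row.sw == "on") {
        out += current
        current = row
      }
    }
    out += current
    out.toList
  }
}
```

Sorting this way materializes each group in memory, so it assumes no single id has more rows than an executor can hold.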
