Andrey

Reputation: 9072

How to aggregate elements of one Akka stream based on elements of another?

Example scenario: group bytes of a stream into chunks of sizes determined by another stream (of integers).

def partition[A, B, C](
  first:Source[A, NotUsed],
  second:Source[B, NotUsed],
  aggregate:(Int => Seq[A], B) => C
):Source[C, NotUsed] = ???

val bytes:Source[Byte, NotUsed] = ???
val sizes:Source[Int, NotUsed] = ???

val chunks:Source[ByteString, NotUsed] =
  partition(bytes, sizes, (grab, count) => ByteString(grab(count)))
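To make the intended behaviour concrete, here is a minimal non-streaming sketch of the same chunking on plain collections. The name `chunkBy` is hypothetical (it is not an Akka API), and stopping at the first size that is non-positive or cannot be satisfied is an assumption about the desired semantics:

```scala
// Hypothetical helper, not part of Akka: a strict, in-memory model of the chunking.
@annotation.tailrec
def chunkBy[A](
  elements: Seq[A],
  sizes: Seq[Int],
  acc: Vector[Seq[A]] = Vector.empty
): Vector[Seq[A]] =
  sizes match {
    // take the next size and split off that many elements
    case n +: rest if n > 0 && elements.length >= n =>
      val (chunk, remainder) = elements.splitAt(n)
      chunkBy(remainder, rest, acc :+ chunk)
    // assumption: stop when sizes run out, a size is non-positive,
    // or too few elements are left for the next chunk
    case _ => acc
  }

val example = chunkBy(Seq[Byte](1, 2, 3, 4, 5, 6, 7), Seq(2, 3, 4))
// Vector(List(1, 2), List(3, 4, 5)); the last size (4) finds only 2 elements left
```

A streaming version would have to produce the same chunks without ever holding the whole input in memory.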

My initial attempt combines Flow#scan and Flow#prefixAndTail, but it doesn't feel quite right (see below). I also looked at Framing, but it doesn't seem applicable to the example scenario above, nor is it general enough to accommodate non-ByteString streams. I'm guessing my only option is to use Graphs (or the more general FlowOps#transform), but I'm not yet proficient enough with Akka streams to attempt that.


Here's what I was able to come up with so far (specific to the example scenario):

val chunks:Source[ByteString, NotUsed] = sizes
  .scan(bytes prefixAndTail 0) {
    (grouped, count) => grouped flatMapConcat {
      case (chunk, remainder) => remainder prefixAndTail count
    }
  }
  .flatMapConcat(identity)
  .collect { case (chunk, _) if chunk.nonEmpty => ByteString(chunk:_*) }

Upvotes: 3

Views: 2378

Answers (1)

lpiepiora

Reputation: 13749

I think you can implement the processing as a custom GraphStage. The stage would have two inlets, one taking the bytes and the other taking the sizes, and one outlet producing the values.

Consider the following input streams.

import akka.NotUsed
import akka.stream.scaladsl.Source
import scala.util.Random

def randomChars = Iterator.continually(Random.nextPrintableChar())
def randomNumbers = Iterator.continually(math.abs(Random.nextInt() % 50))

val bytes: Source[Char, NotUsed] =
  Source.fromIterator(() => randomChars)

val sizes: Source[Int, NotUsed] =
  Source.fromIterator(() => randomNumbers).filter(_ != 0)

Then, following the documentation on custom stream processing (http://doc.akka.io/docs/akka/2.4.2/scala/stream/stream-customize.html), you can construct the GraphStage.

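import akka.stream.{Attributes, FanInShape2, Inlet, Outlet}
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
import akka.util.ByteString

case class ZipFraming() extends GraphStage[FanInShape2[Int, Char, (Int, ByteString)]] {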

  override def initialAttributes = Attributes.name("ZipFraming")

  override val shape: FanInShape2[Int, Char, (Int, ByteString)] =
    new FanInShape2[Int, Char, (Int, ByteString)]("ZipFraming")

  val inFrameSize: Inlet[Int] = shape.in0
  val inElements: Inlet[Char] = shape.in1

  def out: Outlet[(Int, ByteString)] = shape.out

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) {
      // we will buffer up to 512 characters from the input
      val MaxBufferSize = 512
      // the buffer for the received chars
      var buffer = Vector.empty[Char]
      // the needed number of elements
      var needed: Int = -1
      // if the downstream is waiting
      var isDemanding = false

      override def preStart(): Unit = {
        pull(inFrameSize)
        pull(inElements)
      }

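      setHandler(inElements, new InHandler {
        override def onPush(): Unit = {
          // always grab the pushed element: one that is pushed but never
          // grabbed is lost once the port is pulled again
          buffer = buffer :+ grab(inElements)
          // keep pulling only while there is room in the buffer
          if (buffer.size < MaxBufferSize) pull(inElements)
          emit()
        }
      })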

      setHandler(inFrameSize, new InHandler {
        override def onPush(): Unit = {
          needed = grab(inFrameSize)
          emit()
        }
      })

      setHandler(out, new OutHandler {
        override def onPull(): Unit = {
          isDemanding = true
          emit()
        }
      })

      def emit(): Unit = {
        if (needed > 0 && buffer.length >= needed && isDemanding) {
          // split off exactly `needed` elements; the rest stays buffered
          val (chunk, remainder) = buffer.splitAt(needed)
          push(out, (needed, ByteString(chunk.map(_.toByte).toArray)))
          buffer = remainder
          needed = -1
          isDemanding = false
          // ask for the next frame size and resume buffering if it was paused
          pull(inFrameSize)
          if (!hasBeenPulled(inElements)) pull(inElements)
        }
      }
    }
}

And this is how you run it.

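import akka.stream.ClosedShape
import akka.stream.scaladsl.{GraphDSL, Keep, RunnableGraph, Sink}

// run() needs an implicit Materializer in scope
RunnableGraph.fromGraph(GraphDSL.create(bytes, sizes)(Keep.none) { implicit b =>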
  (bs, ss) =>
    import GraphDSL.Implicits._

    val zipFraming = b.add(ZipFraming())

    ss ~> zipFraming.in0
    bs ~> zipFraming.in1

    zipFraming.out ~> Sink.foreach[(Int, ByteString)](e => println((e._1, e._2.utf8String)))

    ClosedShape
}).run()

Upvotes: 4
