Konstantin Kulagin

Reputation: 724

Spark mapPartitions closure behavior

This is a pretty common Spark-related question about which piece of code is executed on which part of Spark (executor/driver). Given this piece of code, I am a bit surprised that I do not get the values I am expecting:

1    stream
2      .foreachRDD((kafkaRdd: RDD[ConsumerRecord[String, String]]) => {
3        val offsetRanges = kafkaRdd.asInstanceOf[HasOffsetRanges].offsetRanges
4        import argonaut.Argonaut.StringToParseWrap
5
6        val rdd: RDD[SimpleData] = kafkaRdd.mapPartitions((records: Iterator[ConsumerRecord[String, String]]) => {
7          val invalidCount: AtomicLong = new AtomicLong(0)
8          val convertedData: Iterator[SimpleData] = records.map(record => {
9            val maybeData: Option[SimpleData] = record.value().decodeOption[SimpleData]
10           if (maybeData.isEmpty) {
11             logger.error("Cannot parse data from kafka: " + record.value())
12             invalidCount.incrementAndGet()
13           }
14           maybeData
15         })
16           .filter(_.isDefined)
17           .map(_.get)
18
19         val statsDClient = new NonBlockingStatsDClient("appName", "monitoring.host", 8125) // I know it should be a singleton :)
20         statsDClient.gauge("invalid-input-records", invalidCount.get())
21
22         convertedData
23       })
24
25       rdd.collect().length
26       stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
27     })

Idea: get JSON data from Kafka and report the number of entries that have an invalid format (if any). I assume that when I use the mapPartitions method, the code inside it will be executed for each partition I have. I.e. I would expect that lines 7-22 will be wrapped in a closure and sent to an executor for execution. In this case I was expecting that the

invalidCount

variable will be in scope on the executor and will be updated whenever an error happens during the JSON -> object conversion (lines 10-13), because internally there is no notion of an RDD or anything like that - there is only a regular Scala iterator over regular entries. In lines 19-20 the statsd client sends the invalidCount value to the metrics server. However, I always get '0' as a result.

However, if I change the code to this:

1     stream
2       .foreachRDD((kafkaRdd: RDD[ConsumerRecord[String, String]]) => {
3         val offsetRanges = kafkaRdd.asInstanceOf[HasOffsetRanges].offsetRanges
4
5         // this is ugly we have to repeat it - but argonaut is NOT serializable...
6         val rdd: RDD[SimpleData] = kafkaRdd.mapPartitions((records: Iterator[ConsumerRecord[String, String]]) => {
7           import argonaut.Argonaut.StringToParseWrap
8            val convertedDataTest: Iterator[(Option[SimpleData], String)] = records.map(record => {
9             val maybeData: Option[SimpleData] = record.value().decodeOption[SimpleData]
10            (maybeData, record.value())
11          })
12
13          val testInvalidDataEntries: Int = convertedDataTest.count(record => {
14            val empty = record._1.isEmpty
15            if (empty) {
16              logger.error("Cannot parse data from kafka: " + record._2)
17            }
18            empty
19          })
20          val statsDClient = new NonBlockingStatsDClient("appName", "monitoring.host", 8125) // I know it should be a singleton :)
21          statsDClient.gauge("invalid-input-records", testInvalidDataEntries)
22
23          convertedDataTest
24            .filter(maybeData => maybeData._1.isDefined)
25            .map(data => data._1.get)
26        })
27
28        rdd.collect().length
29        stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
30      })

It works as expected. I.e. if I count the invalid entries explicitly, I get the expected value.

I am not sure I understand why. Ideas?

Code to play with can be found on GitHub.

Upvotes: 1

Views: 1274

Answers (1)

GPI

Reputation: 9338

The reason is actually pretty simple and not at all Spark-related.

See this Scala console sample, which does not involve Spark at all:

scala> val iterator: Iterator[String] = Seq("a", "b", "c").iterator
    iterator: Iterator[String] = non-empty iterator

scala> val count = new java.util.concurrent.atomic.AtomicInteger(0)
    count: java.util.concurrent.atomic.AtomicInteger = 0

scala> val mappedIterator = iterator.map(letter => {print("mapping!! "); count.incrementAndGet(); letter})
    mappedIterator: Iterator[String] = non-empty iterator

scala> count.get
    res3: Int = 0

See how I start with an iterator and a fresh counter; I map over this iterator, but nothing happens: the print does not show, and the count is still zero.

But when I materialize the mappedIterator's content:

scala> mappedIterator.next
    mapping!! res1: String = a

Now something happened: I got a print, and an incremented counter.

scala> count.get
    res2: Int = 1

The same happens in your code on the Spark executor.

This is because Scala iterators are lazy with respect to the map operation (see also here and here).
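For completeness: draining the rest of that iterator (which is essentially what Spark does with the iterator returned from mapPartitions) finally runs every remaining map call. A quick sketch, continuing the same session:

    // Draining the remaining elements forces the remaining map() calls to run.
    mappedIterator.toList   // prints "mapping!! mapping!! " and returns List("b", "c")
    count.get               // now 3: every element has gone through the mapping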

So, in your first sample, what chronologically happens is:

  1. You define a transformation on the original partition iterator (but you do not execute the transformation itself)
  2. You push your counter value, which is still at its initial state (because no transformation has occurred yet)
  3. You pass Spark the transformation-ready iterator
  4. Spark actually iterates over this result, and so the mapping occurs. But your side effect (step 2) has already been performed. (One way to work around this from inside mapPartitions is sketched just below.)
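If the gauge really has to be emitted from inside mapPartitions, the partition iterator has to be drained before the counter is read, so that step 2 happens after the mapping. A minimal sketch of that idea, reusing SimpleData, logger and the statsd client from the question (note that it buffers each whole partition in memory):

    val rdd: RDD[SimpleData] = kafkaRdd.mapPartitions { records =>
      import argonaut.Argonaut.StringToParseWrap

      var invalidCount = 0L
      // toVector drains the partition iterator, so the mapping (and the counting) runs here
      val converted: Vector[SimpleData] = records.flatMap { record =>
        val maybeData = record.value().decodeOption[SimpleData]
        if (maybeData.isEmpty) {
          logger.error("Cannot parse data from kafka: " + record.value())
          invalidCount += 1
        }
        maybeData
      }.toVector

      // The counter now reflects this partition, so the gauge is meaningful
      val statsDClient = new NonBlockingStatsDClient("appName", "monitoring.host", 8125)
      statsDClient.gauge("invalid-input-records", invalidCount)

      converted.iterator
    }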

In the second scenario, you call val testInvalidDataEntries: Int = convertedDataTest.count... which consumes the iterator and performs the actual mapping (and, in the process, the counting) before sending your counter to your server.

So, it is this laziness that makes your two samples behave differently.

(That is also why, generally and theoretically speaking, we tend not to use side effects in map operations in functional programming-oriented languages: the result becomes dependent on the order of execution, whereas pure functional style should prevent this.)

One way to count failures could be to use a Spark Accumulator: accumulate the count inside the tasks, and perform your statsd update on the driver side after the RDD has gone through a terminal operation.
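A rough sketch of that approach, assuming a SparkSession named spark is in scope inside foreachRDD and that the statsd client is created on the driver (the other names are taken from the question):

    // Driver side: a LongAccumulator (Spark 2.x) shared with the tasks
    val invalidCount = spark.sparkContext.longAccumulator("invalid-input-records")
    val statsDClient = new NonBlockingStatsDClient("appName", "monitoring.host", 8125)

    val rdd: RDD[SimpleData] = kafkaRdd.mapPartitions { records =>
      import argonaut.Argonaut.StringToParseWrap
      records.flatMap { record =>
        val maybeData = record.value().decodeOption[SimpleData]
        if (maybeData.isEmpty) {
          logger.error("Cannot parse data from kafka: " + record.value())
          invalidCount.add(1)    // still lazy, but it runs inside the task when Spark iterates
        }
        maybeData
      }
    }

    rdd.collect().length                                           // terminal operation: the mapping runs here
    statsDClient.gauge("invalid-input-records", invalidCount.sum)  // read on the driver, after the action

Keep in mind that accumulator updates made inside transformations may be applied more than once if a task is retried, so the value is best treated as a monitoring metric rather than an exact count.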

Upvotes: 3
