Alexis Seigneurin

Reputation: 1483

Transforming PySpark RDD with Scala

TL;DR - I have what looks like a DStream of Strings in a PySpark application. I want to send it as a DStream[String] to a Scala library. Strings are not converted by Py4j, though.

I'm working on a PySpark application that pulls data from Kafka using Spark Streaming. My messages are strings and I would like to call a method in Scala code, passing it a DStream[String] instance. However, I'm unable to receive proper JVM strings in the Scala code. It looks to me like the Python strings are not converted into Java strings but, instead, are serialized.

My question is: how can I get Java strings out of the DStream object?


Here is the simplest Python code I came up with:

from pyspark.streaming import StreamingContext
ssc = StreamingContext(sparkContext=sc, batchDuration=int(1))

from pyspark.streaming.kafka import KafkaUtils
stream = KafkaUtils.createDirectStream(ssc, ["IN"], {"metadata.broker.list": "localhost:9092"})
# Each Kafka record is a (key, value) pair; keep only the message value.
values = stream.map(lambda kv: kv[1])

ssc._jvm.com.seigneurin.MyPythonHelper.doSomething(values._jdstream)

ssc.start()

I'm running this code in PySpark, passing it the path to my JAR:

pyspark --driver-class-path ~/path/to/my/lib-0.1.1-SNAPSHOT.jar

On the Scala side, I have:

package com.seigneurin

import org.apache.spark.streaming.api.java.JavaDStream

object MyPythonHelper {
  def doSomething(jdstream: JavaDStream[String]) = {
    val dstream = jdstream.dstream
    dstream.foreachRDD(rdd => {
      rdd.foreach(println)
    })
  }
}

Now, let's say I send some data into Kafka:

echo 'foo bar' | $KAFKA_HOME/bin/kafka-console-producer.sh --broker-list localhost:9092 --topic IN

The println statement in the Scala code prints something that looks like:

[B@758aa4d9

I expected to get foo bar instead.

Now, if I replace the simple println statement in the Scala code with the following:

rdd.foreach(v => println(v.getClass.getCanonicalName))

I get:

java.lang.ClassCastException: [B cannot be cast to java.lang.String

This suggests that the strings are actually passed as arrays of bytes.

If I simply try to convert this array of bytes into a string (I know I'm not even specifying the encoding):

      def doSomething(jdstream: JavaDStream[Array[Byte]]) = {
        val dstream = jdstream.dstream
        dstream.foreachRDD(rdd => {
          rdd.foreach(bytes => println(new String(bytes)))
        })
      }

I get something that looks like (special characters might be stripped off):

�]qXfoo barqa.

This suggests the Python string was serialized (pickled?). How could I retrieve a proper Java string instead?

Upvotes: 5

Views: 2439

Answers (1)

zero323

Reputation: 330383

Long story short, there is no supported way to do something like this. Don't try this in production. You've been warned.

In general, Spark doesn't use Py4j for anything other than some basic RPC calls on the driver, and it doesn't start a Py4j gateway on any other machine. When it is required (mostly MLlib and some parts of SQL), Spark uses Pyrolite to serialize objects passed between the JVM and Python.
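
To make the byte arrays from the question less mysterious, here is a minimal sketch in plain Python (no Spark involved). It assumes that PySpark's default serializer pickles records in batches, i.e. as a list, and that pickle protocol 2 is in use; the actual protocol depends on the Python and Spark versions. The helper names `pickle_batch` and `printable_ascii` are made up for this illustration.

```python
import pickle


def pickle_batch(records, protocol=2):
    """Pickle a batch (list) of records.

    Assumption: this only approximates what PySpark's batched pickle
    serializer does; it is not Spark's actual code path.
    """
    return pickle.dumps(list(records), protocol=protocol)


def printable_ascii(payload):
    """Keep only printable ASCII bytes, similar to what a console shows."""
    return "".join(chr(b) for b in payload if 32 <= b < 127)


payload = pickle_batch(["foo bar"])

print(payload)                   # raw pickle bytes, opcodes included
print(printable_ascii(payload))  # readable residue of those bytes
print(pickle.loads(payload))     # the original list, recovered by unpickling
```

The readable residue has the same shape as the output in the question: the text is surrounded by pickle opcodes, so decoding the bytes with `new String(bytes)` on the JVM side cannot give back a clean string. The payload has to be unpickled, which is what Pyrolite does on the JVM.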

This part of the API is either private (Scala) or internal (Python), and as such is not intended for general use. In theory you can access it anyway, either per batch:

package dummy

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.streaming.api.java.JavaDStream
import org.apache.spark.sql.DataFrame

object PythonRDDHelper {
  def go(rdd: JavaRDD[Any]) = {
    rdd.rdd.collect {
      case s: String => s
    }.take(5).foreach(println)
  }
}

for the complete stream:

object PythonDStreamHelper {
  def go(stream: JavaDStream[Any]) = {
    stream.dstream.transform(_.collect {
      case s: String => s
    }).print
  }
}

or exposing individual batches as DataFrames (probably the least evil option):

object PythonDataFrameHelper {
  def go(df: DataFrame) = {
    df.show
  }
}

and use these wrappers as follows:

from pyspark.streaming import StreamingContext
from pyspark.mllib.common import _to_java_object_rdd
from pyspark.rdd import RDD

ssc = StreamingContext(spark.sparkContext, 10)
spark.catalog.listTables()

q = ssc.queueStream([sc.parallelize(["foo", "bar"]) for _ in range(10)]) 

# Reserialize RDD as Java RDD<Object> and pass 
# to Scala sink (only for output)
q.foreachRDD(lambda rdd: ssc._jvm.dummy.PythonRDDHelper.go(
    _to_java_object_rdd(rdd)
))

# Reserialize and convert to JavaDStream<Object>
# This is the only option which allows further transformations
# on DStream
ssc._jvm.dummy.PythonDStreamHelper.go(
    q.transform(lambda rdd: RDD(  # Reserialize but keep as Python RDD
        _to_java_object_rdd(rdd), ssc.sparkContext
    ))._jdstream
)

# Convert to DataFrame and pass to Scala sink.
# Arguably there are relatively few moving parts here. 
q.foreachRDD(lambda rdd: 
    ssc._jvm.dummy.PythonDataFrameHelper.go(
        rdd.map(lambda x: (x, )).toDF()._jdf
    )
)

ssc.start()
ssc.awaitTerminationOrTimeout(30)
ssc.stop()

This is unsupported and untested, and as such rather useless for anything other than experiments with the Spark API.

Upvotes: 7
