GSanta

Reputation: 43

Spark RDD or SQL operations to compute conditional counts

As a bit of background, I'm trying to implement the Kaplan-Meier estimator in Spark. In particular, I assume I have a data frame/set with a Double column named data and an Int column named censorFlag (0 if censored, 1 if not; I prefer this over a Boolean type).

Example:

val df = Seq((1.0, 1), (2.3, 0), (4.5, 1), (0.8, 1), (0.7, 0), (4.0, 1), (0.8, 1)).toDF("data", "censorFlag").as[(Double, Int)] 

Now I need to compute a wins column that counts, for each data value, the number of uncensored observations (censorFlag == 1) at that value. I achieve that with the following code:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val distDF = df.withColumn("wins", sum(col("censorFlag")).over(Window.partitionBy("data").orderBy("data")))
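
For the example data above, this gives (order aside, and if I'm reading the window frame right) something like:

+----+----------+----+
|data|censorFlag|wins|
+----+----------+----+
|0.7 |0         |0   |
|0.8 |1         |2   |
|0.8 |1         |2   |
|1.0 |1         |1   |
|2.3 |0         |0   |
|4.0 |1         |1   |
|4.5 |1         |1   |
+----+----------+----+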

The problem comes when I need to compute a quantity called atRisk which counts, for each value of data, the number of data points that are greater than or equal to it (a cumulative filtered count, if you will).
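
For the example above (counting every row, duplicates included), the atRisk values I'm after would be, if I've tallied by hand correctly:

+----+------+
|data|atRisk|
+----+------+
|0.7 |7     |
|0.8 |6     |
|1.0 |4     |
|2.3 |3     |
|4.0 |2     |
|4.5 |1     |
+----+------+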

The following code works:

// We perform the counts per value of "bins". This is an array of doubles
val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins").as[Double].collect 
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// this works:
atRiskCounts.show

However, my use case involves deriving bins from the data column itself, which I'd rather keep as a single-column Dataset (or an RDD at worst), but certainly not as a local array. The following doesn't work:

// Here, 'bins' rightfully comes from the data itself.
val bins = df.select(col("data").as("dataBins")).distinct().as[Double]
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// This doesn't work -- NullPointerException
atRiskCounts.show

Nor does this:

// Manually creating the bins and then parallelizing them.
val bins = Seq(0.7, 0.8, 1.0, 3.0).toDS
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toDF("data", "atRisk")
// Also fails with a NullPointerException
atRiskCounts.show

Another approach that does work, but is not satisfactory from a parallelization perspective either, is using a Window:

// Do the counts in one fell swoop using a giant window per value.
val atRiskCounts = df.withColumn("atRisk", count("censorFlag").over(Window.orderBy("data").rowsBetween(0, Window.unboundedFollowing))).groupBy("data").agg(first("atRisk").as("atRisk"))
// Works, BUT, we get a "WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation." 
atRiskCounts.show

This last solution isn't useful, as it ends up shuffling my data to a single partition (and in that case, I might as well go with the first approach, which works).

The successful approaches are fine, except that the bins are not processed in parallel, which is something I'd really like to keep if possible. I've looked at groupBy aggregations and pivot-type aggregations, but none seem to make sense here.

My question is: is there any way to compute the atRisk column in a distributed way? Also, why do I get a NullPointerException in the failed solutions?

EDIT PER COMMENT:

I didn't originally post the NullPointerException as it didn't seem to include anything useful. I'll note that this is Spark installed via Homebrew on my MacBook Pro (Spark version 2.2.1, standalone localhost mode).

                18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.package on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/package.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.scala on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/scala.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.org on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/org.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.java on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/java.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR Executor: Exception in task 0.0 in stage 55.0 (TID 432)
            java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
                at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
                at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
                at java.lang.Thread.run(Thread.java:748)
            18/03/12 11:41:00 WARN TaskSetManager: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)

            18/03/12 11:41:00 ERROR TaskSetManager: Task 0 in stage 55.0 failed 1 times; aborting job
            org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 55.0 failed 1 times, most recent failure: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $anonfun$1.apply(<console>:33)
                at $anonfun$1.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)

            Driver stacktrace:
              at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1517)
              at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1505)
              at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1504)
              at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
              ... 50 elided
            Caused by: java.lang.NullPointerException
              at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
              at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
              at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
              at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
              at $anonfun$1.apply(<console>:33)
              at $anonfun$1.apply(<console>:33)
              at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
              at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
              at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
              at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
              at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
              at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
              at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
              at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
              at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
              at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
              at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
              at org.apache.spark.scheduler.Task.run(Task.scala:108)
              at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
              at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
              at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
              at java.lang.Thread.run(Thread.java:748)

My best guess is that the df.filter(col("data").geq(x)).count call inside the map is the part that barfs, as the full df may not be available on every node, and thus the null pointer?

Upvotes: 3

Views: 741

Answers (3)

GSanta

Reputation: 43

I tried the proposed approaches (albeit not the most rigorously!), and it seems the join works best in general.

The data:

import org.apache.spark.mllib.random.RandomRDDs._
import org.apache.spark.sql.functions._
val df = logNormalRDD(sc, 1, 3.0, 10000, 100).zip(uniformRDD(sc, 10000, 100).map(x => if(x <= 0.4) 1 else 0)).toDF("data", "censorFlag").withColumn("data", round(col("data"), 2))

The join example:

def runJoin(sc: SparkContext, df:DataFrame): Unit = {
  val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins")
  val wins = df.groupBy(col("data")).agg(sum("censorFlag").as("wins"))
  val atRiskCounts = bins.join(df, bins("dataBins") <= df("data")).groupBy("dataBins").count().withColumnRenamed("count", "atRisk")
  val finalDF = wins.join(atRiskCounts, wins("data") === atRiskCounts("dataBins")).select("data", "wins", "atRisk").sort("data")
  finalDF.show
}

The broadcast example:

def runBroadcast(sc: SparkContext, df: DataFrame): Unit = {
  val bins = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]
  val binsBroadcast = sc.broadcast(bins)
  val df2 = binsBroadcast.value.map(x => (x, df.filter(col("data").geq(x)).select(count(col("data"))).as[Long].first)).toDF("data", "atRisk")
  val finalDF = df.groupBy(col("data")).agg(sum("censorFlag").as("wins")).join(df2, "data")
  finalDF.show
  binsBroadcast.destroy
}

And the testing code:

import java.util.concurrent.TimeUnit

// sampleDF is a DataFrame generated as above (with varying sizes)
var start = System.nanoTime()
runJoin(sc, sampleDF)
val joinTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)

start = System.nanoTime()
runBroadcast(sc, sampleDF)
val broadTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)

I ran this code for different sizes of the random data and supplied manual bins arrays (some very granular, 50% of the original distinct data; some very small, 10% of the original distinct data), and the join approach consistently seems to be the fastest (although both arrive at the same solution, so that is a plus!).

On average I find that the smaller the bins array, the better the broadcast approach works, but the join doesn't seem much affected. If I had more time/resources to test this, I'd run lots of simulations to see what the average run time looks like, but for now I'll accept @hoyland's solution.

I'm still not sure why the original approach didn't work, so I'm open to comments on that.

Kindly let me know of any issues in my code, or improvements! Thank you both :)

Upvotes: 1

Ramesh Maharjan

Reputation: 41957

With a requirement like yours, where each value in a column has to be checked against all the other values in that column, collecting is unavoidable: to compare against all the values, the data of that column has to be gathered on one executor or on the driver. You cannot avoid that step with a requirement like yours.

The main part is how you define the rest of the steps so that they benefit from Spark's parallelization. I would suggest broadcasting the collected set (since it is only the distinct data of one column, it should not be huge) and using a udf function to check the >= condition, as below.

First, you can optimize your collection step as

import org.apache.spark.sql.functions._
val collectedData = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]

Then you broadcast the collected set

val broadcastedArray = sc.broadcast(collectedData)

The next step is to define a udf function that checks the >= condition and returns the counts

def checkingUdf = udf((data: Double) => broadcastedArray.value.count(x => x >= data))

and use it as

distDF.withColumn("atRisk", checkingUdf(col("data"))).show(false)

So, finally, you should have

+----+----------+----+------+
|data|censorFlag|wins|atRisk|
+----+----------+----+------+
|4.5 |1         |1   |1     |
|0.7 |0         |0   |6     |
|2.3 |0         |0   |3     |
|1.0 |1         |1   |4     |
|0.8 |1         |2   |5     |
|0.8 |1         |2   |5     |
|4.0 |1         |1   |2     |
+----+----------+----+------+

I hope that's the required dataframe.

Upvotes: 1

hoyland

Reputation: 1824

I have not tested this, so the syntax may be goofy, but I would do a series of joins.

I believe your first statement is equivalent to the following: for each data value, count how many wins there are:

val distDF = df.groupBy($"data").agg(sum($"censorFlag").as("wins"))

Then, as you noted, we can build a dataframe of the bins:

val distinctData = df.select($"data".as("dataBins")).distinct()

And then join with a >= condition:

val atRiskCounts = distDF.join(distinctData, distDF("data") >= distinctData("dataBins"))
  .groupBy($"data", $"wins")
  .count()
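
The count column here is your atRisk, so (again untested, just a sketch) something like this should give the final table:

// Rename the aggregated count to atRisk and sort by data for readability
val result = atRiskCounts
  .withColumnRenamed("count", "atRisk")
  .select($"data", $"wins", $"atRisk")
  .orderBy($"data")
result.show()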

Upvotes: 1
