Reputation: 43
As a bit of background, I'm trying to implement the Kaplan-Meier in Spark. In particular, I assume I have a data frame/set with a Double
column denoted as Data
and an Int
column named censorFlag
(0
value if censored, 1
if not, prefer this over Boolean
type).
Example:
val df = Seq((1.0, 1), (2.3, 0), (4.5, 1), (0.8, 1), (0.7, 0), (4.0, 1), (0.8, 1)).toDF("data", "censorFlag").as[(Double, Int)]
Now I need to compute a column wins
that counts instances of each data
value. I achieve that with the following code:
val distDF = df.withColumn("wins", sum(col("censorFlag")).over(Window.partitionBy("data").orderBy("data")))
The problem comes when I need to compute a quantity called atRisk
which counts, for each value of data
, the number of data
points that are greater than or equal to it (a cumulative filtered count, if you will).
The following code works:
// We perform the counts per value of "bins". This is an array of doubles
val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins").as[Double].collect
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// this works:
atRiskCounts.show
However, the use case involves deriving bins
from the column data
itself, which I'd rather leave as a single column data set (or RDD at worst), but certainly not local array. But this doesn't work:
// Here, 'bins' rightfully come from the data itself.
val bins = df.select(col("data").as("dataBins")).distinct().as[Double]
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// This doesn't work -- NullPointerException
atRiskCounts.show
Nor does this:
// Manually creating the bins and then parallelizing them.
val bins = Seq(0.7, 0.8, 1.0, 3.0).toDS
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toDF("data", "atRisk")
// Also fails with a NullPointerException
atRiskCounts.show
Another approach that does work, but is also not satisfactory from a parallelization perspective is using Window
:
// Do the counts in one fell swoop using a giant window per value.
val atRiskCounts = df.withColumn("atRisk", count("censorFlag").over(Window.orderBy("data").rowsBetween(0, Window.unboundedFollowing))).groupBy("data").agg(first("atRisk").as("atRisk"))
// Works, BUT, we get a "WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation."
atRiskCounts.show
This last solution isn't useful as it ends up shuffling my data to a single partition (and in that case, I might as well go with Option 1 tha works).
The successful approaches are fine except that the bins are not parallel, which is something I'd really like to keep if possible. I've looked at groupBy
aggregations, pivot
type of aggregations, but none seem to make sense.
My question is: is there any way to compute atRisk
column in a distributed way? Also, why do I get a NullPointerException
in the failed solutions?
EDIT PER COMMENT:
I didn't originally post the NullPointerException
as it didn't seem to include anything useful. I'll make a note that this is Spark installed via homebrew on my Macbook Pro (Spark version 2.2.1, standalone localhost mode).
18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.package on REPL class server at spark://10.37.109.111:53360/classes
java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/package.class
at java.net.URI$Parser.fail(URI.java:2848)
at java.net.URI$Parser.checkChars(URI.java:3021)
at java.net.URI$Parser.parseHierarchical(URI.java:3105)
at java.net.URI$Parser.parse(URI.java:3053)
at java.net.URI.<init>(URI.java:588)
at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
. . . .
18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.scala on REPL class server at spark://10.37.109.111:53360/classes
java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/scala.class
at java.net.URI$Parser.fail(URI.java:2848)
at java.net.URI$Parser.checkChars(URI.java:3021)
at java.net.URI$Parser.parseHierarchical(URI.java:3105)
at java.net.URI$Parser.parse(URI.java:3053)
at java.net.URI.<init>(URI.java:588)
at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
. . .
18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.org on REPL class server at spark://10.37.109.111:53360/classes
java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/org.class
at java.net.URI$Parser.fail(URI.java:2848)
at java.net.URI$Parser.checkChars(URI.java:3021)
at java.net.URI$Parser.parseHierarchical(URI.java:3105)
at java.net.URI$Parser.parse(URI.java:3053)
at java.net.URI.<init>(URI.java:588)
at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
. . .
18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.java on REPL class server at spark://10.37.109.111:53360/classes
java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/java.class
at java.net.URI$Parser.fail(URI.java:2848)
at java.net.URI$Parser.checkChars(URI.java:3021)
at java.net.URI$Parser.parseHierarchical(URI.java:3105)
at java.net.URI$Parser.parse(URI.java:3053)
at java.net.URI.<init>(URI.java:588)
at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
. . .
18/03/12 11:41:00 ERROR Executor: Exception in task 0.0 in stage 55.0 (TID 432)
java.lang.NullPointerException
at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:748)
18/03/12 11:41:00 WARN TaskSetManager: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
18/03/12 11:41:00 ERROR TaskSetManager: Task 0 in stage 55.0 failed 1 times; aborting job
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 55.0 failed 1 times, most recent failure: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
at $anonfun$1.apply(<console>:33)
at $anonfun$1.apply(<console>:33)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1517)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1505)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1504)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
... 50 elided
Caused by: java.lang.NullPointerException
at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
at $anonfun$1.apply(<console>:33)
at $anonfun$1.apply(<console>:33)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:748)
My best guess is that the line df("data").geq(x).count
might be the part that barfs as not every node may have x
and thus a null pointer?
Upvotes: 3
Views: 741
Reputation: 43
I tried the above examples (albeit not the most rigorously!), and it seems the left join
works best in general.
The data:
import org.apache.spark.mllib.random.RandomRDDs._
val df = logNormalRDD(sc, 1, 3.0, 10000, 100).zip(uniformRDD(sc, 10000, 100).map(x => if(x <= 0.4) 1 else 0)).toDF("data", "censorFlag").withColumn("data", round(col("data"), 2))
The join example:
def runJoin(sc: SparkContext, df:DataFrame): Unit = {
val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins")
val wins = df.groupBy(col("data")).agg(sum("censorFlag").as("wins"))
val atRiskCounts = bins.join(df, bins("dataBins") <= df("data")).groupBy("dataBins").count().withColumnRenamed("count", "atRisk")
val finalDF = wins.join(atRiskCounts, wins("data") === atRiskCounts("dataBins")).select("data", "wins", "atRisk").sort("data")
finalDF.show
}
The broadcast example:
def runBroadcast(sc: SparkContext, df: DataFrame): Unit = {
val bins = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]
val binsBroadcast = sc.broadcast(bins)
val df2 = binsBroadcast.value.map(x => (x, df.filter(col("data").geq(x)).select(count(col("data"))).as[Long].first)).toDF("data", "atRisk")
val finalDF = df.groupBy(col("data")).agg(sum("censorFlag").as("wins")).join(df2, "data")
finalDF.show
binsBroadcast.destroy
}
And the testing code:
var start = System.nanoTime()
runJoin(sc, sampleDF)
val joinTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)
start = System.nanoTime()
runBroadcast(sc, sampleDF)
val broadTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)
I ran this code for different sizes of the random data, provided manual bins
arrays (some very granular, 50% of original distinct data, some very small, 10% of original distinct data), and consistently it seems the join
approach is the fastest (although both arrive at the same solution, so that is a plus!).
On average I find that the smaller the bin array, the better broadcast
approach works, but join
doesn't seem too affected. If I had more time/resource to test this, I'd run lots of simulations to see what the average run time looks like, but for now I'll accept @hoyland's solution.
Still have not sure why the original approach didn't work, so open to comments on that.
Kindly let me know of any issues in my code, or improvements! Thank you both :)
Upvotes: 1
Reputation: 41957
When there is a requirement as yours to check a value in a column with all the rest of the values in that column, collection is the most important. And when there is requirement to check with all the values then it is certain that all the data of that column need to be accumulated in one executor or driver. You cannot avoid the step when there is a requirement as yours.
Now the main part is how you define the rest of the steps to benefit from the parallelization of spark. I would suggest you to broadcast
the collected set (as its distinct data of one column only so they must not be huge) and use a udf
function for checking the gte
condition as below
firstly you can optimize the collection step of yours as
import org.apache.spark.sql.functions._
val collectedData = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]
Then you broadcast
the collected set
val broadcastedArray = sc.broadcast(collectedData)
Next step is to define a udf
function and check the gte
condition and return counts
def checkingUdf = udf((data: Double)=> broadcastedArray.value.count(x => x >= data))
and use it as
distDF.withColumn("atRisk", checkingUdf(col("data"))).show(false)
So that finally you should have
+----+----------+----+------+
|data|censorFlag|wins|atRisk|
+----+----------+----+------+
|4.5 |1 |1 |1 |
|0.7 |0 |0 |6 |
|2.3 |0 |0 |3 |
|1.0 |1 |1 |4 |
|0.8 |1 |2 |5 |
|0.8 |1 |2 |5 |
|4.0 |1 |1 |2 |
+----+----------+----+------+
I hope thats the required dataframe
Upvotes: 1
Reputation: 1824
I have not tested this so the syntax may be goofy, but I would do a series of joins:
I believe your first statement is equivalent to this--for each data
value, count how many wins
there are:
val distDF = df.groupBy($"data").agg(sum($"censorFlag").as("wins"))
Then, as you noted, we can build a dataframe of the bins:
val distinctData = df.select($"data".as("dataBins")).distinct()
And then join with a >=
condition:
val atRiskCounts = distDF.join(distinctData, distDF.data >= distinctData.dataBins)
.groupBy($"data", $"wins")
.count()
Upvotes: 1