Reputation: 33
I have multiple Spark jobs which share a part of the dataflow graph, including an expensive shuffle operation. If I persist that RDD, I see a huge improvement (22x), as expected.
However, even when I keep the storage level of those RDDs as NONE, I still see up to a 4x improvement just by sharing the RDDs among jobs.
Why? I am under the assumption that Spark always recomputes RDDs with storage level NONE and that those are not subject to eviction/spilling.
My Spark version is 3.3.1. Showing the code is difficult as it is spread across multiple files in a bigger system. I am essentially doing the following (roughly sketched below):
If I persist the RDDs in the second step (by calling rdd.persist(StorageLevel.MEMORY_AND_DISK)), I see a huge improvement. But even if I just reuse the same RDD (storage level NONE), I still see an improvement.
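Roughly, the shared part and its reuse look like the following sketch; the names, input path, and transformations are made up for illustration and are not my actual code:
import org.apache.spark.SparkContext
import org.apache.spark.storage.StorageLevel

val sc = new SparkContext("local[*]", "shared-rdd-sketch")

// Shared, shuffle-heavy part of the dataflow graph (made-up transformations).
val shared = sc.textFile("/path/to/input")
  .map(line => (line.split(",")(0), 1L))
  .reduceByKey(_ + _)                      // the expensive shuffle

// Optionally: shared.persist(StorageLevel.MEMORY_AND_DISK)

// Two "jobs" that both reuse the same RDD object.
val countA = shared.filter { case (k, _) => k.startsWith("a") }.count()
val countB = shared.mapValues(_ * 2).count()
Both count() actions run as separate jobs, but they hold a reference to the same shared RDD object.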
Upvotes: 1
Views: 108
Reputation: 5973
If we have a look at the source code for the rdd.persist(StorageLevel) method, we see the following:
def persist(newLevel: StorageLevel): this.type = {
  if (isLocallyCheckpointed) {
    // This means the user previously called localCheckpoint(), which should have already
    // marked this RDD for persisting. Here we should override the old storage level with
    // one that is explicitly requested by the user (after adapting it to use disk).
    persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true)
  } else {
    persist(newLevel, allowOverride = false)
  }
}
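As a quick aside, here is a rough sketch of how each of those two branches can be reached from user code (not taken from the question; runnable in spark-shell, where sc is predefined):
import org.apache.spark.storage.StorageLevel

// else branch: a plain persist on an RDD that was never locally checkpointed.
val a = sc.parallelize(1 to 1000)
a.persist(StorageLevel.MEMORY_ONLY)

// if branch: localCheckpoint() already marked the RDD for persisting, so a later
// persist() overrides that level, adapted so that it also uses disk.
val b = sc.parallelize(1 to 1000)
b.localCheckpoint()
b.persist(StorageLevel.MEMORY_AND_DISK)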
So the public persist method above calls a private persist method with an extra input argument. It looks like this:
/**
 * Mark this RDD for persisting using the specified level.
 *
 * @param newLevel the target storage level
 * @param allowOverride whether to override any existing level with the new one
 */
private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
  // TODO: Handle changes of StorageLevel
  if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
    throw SparkCoreErrors.cannotChangeStorageLevelError()
  }
  // If this is the first time this RDD is marked for persisting, register it
  // with the SparkContext for cleanups and accounting. Do this only once.
  if (storageLevel == StorageLevel.NONE) {
    sc.cleaner.foreach(_.registerRDDForCleanup(this))
    sc.persistRDD(this)
  }
  storageLevel = newLevel
  this
}
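To make the guard at the top of that method concrete, here is a small sketch (again runnable in spark-shell):
import org.apache.spark.storage.StorageLevel

val rdd = sc.parallelize(1 to 100)

rdd.persist(StorageLevel.MEMORY_ONLY)  // first call: storageLevel was NONE, so the RDD gets registered
rdd.persist(StorageLevel.MEMORY_ONLY)  // same level again: the guard does not fire, no error
rdd.persist(StorageLevel.DISK_ONLY)    // different level with allowOverride = false:
                                       // throws SparkCoreErrors.cannotChangeStorageLevelError()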
Further down in that private method, we see something interesting: if the current storageLevel (not the new one) is StorageLevel.NONE, Spark calls registerRDDForCleanup and persistRDD for this RDD.
Now, the default value for storageLevel is StorageLevel.NONE. That means that your case (calling persist on an unpersisted RDD) falls under this category.
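You can check that default from user code with RDD.getStorageLevel, e.g. in spark-shell:
import org.apache.spark.storage.StorageLevel

val fresh = sc.parallelize(1 to 10)
println(fresh.getStorageLevel == StorageLevel.NONE)             // true: no level assigned yet

fresh.persist(StorageLevel.MEMORY_AND_DISK)
println(fresh.getStorageLevel == StorageLevel.MEMORY_AND_DISK)  // true after persisting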
So we found out that calling rdd.persist(StorageLevel.NONE) actually does something with your RDD! Let's have a look at both of these operations.
registerRDDForCleanup is a method of the ContextCleaner class. It looks like this:
/** Register an RDD for cleanup when it is garbage collected. */
def registerRDDForCleanup(rdd: RDD[_]): Unit = {
  registerForCleanup(rdd, CleanRDD(rdd.id))
}

// some other code between here that I removed for this explanation

/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
  referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue))
}
So this method actually adds a cleanup task (associated with this RDD) to some buffer called the referenceBuffer. That looks like this:
/**
 * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
 * have not been handled by the reference queue.
 */
private val referenceBuffer =
  Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap)
As the code comment says, this referenceBuffer exists to ensure that the cleanup tasks are not garbage collected before they have been handled. So your RDDs are garbage collected less eagerly, which improves your performance!
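For intuition only, here is a tiny, simplified sketch of the weak-reference-plus-reference-queue pattern this is built on, using plain java.lang.ref rather than Spark's actual classes:
import java.lang.ref.{ReferenceQueue, WeakReference}
import java.util.Collections
import java.util.concurrent.ConcurrentHashMap

// Simplified stand-in for CleanupTaskWeakReference: a weak reference to an object,
// tagged with the cleanup work to do once that object has been garbage collected.
class TaggedWeakRef(val task: String, obj: AnyRef, queue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](obj, queue)

val queue = new ReferenceQueue[AnyRef]

// The "referenceBuffer": strong references to the weak-reference wrappers themselves,
// so the wrappers stay alive until their cleanup task has been handled.
val buffer = Collections.newSetFromMap[TaggedWeakRef](new ConcurrentHashMap)

def register(obj: AnyRef, task: String): Unit =
  buffer.add(new TaggedWeakRef(task, obj, queue))

// A cleaner thread would poll the queue and run the task for each collected object:
// val ref = queue.remove().asInstanceOf[TaggedWeakRef]
// println(s"running cleanup: ${ref.task}")
// buffer.remove(ref)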
The second method that was called on our RDD is the persistRDD method. I won't go into too much detail (since it is less important here), but this method (in SparkContext.scala) basically adds the RDD to a Map in which the SparkContext keeps track of all persisted RDDs.
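You can observe that bookkeeping from user code through SparkContext.getPersistentRDDs, which exposes that map; a quick sketch (again in spark-shell):
import org.apache.spark.storage.StorageLevel

val rdd = sc.parallelize(1 to 10)
println(sc.getPersistentRDDs.contains(rdd.id))  // false: nothing registered yet

rdd.persist(StorageLevel.MEMORY_AND_DISK)
println(sc.getPersistentRDDs.contains(rdd.id))  // true: persistRDD added it to the map

rdd.unpersist()
println(sc.getPersistentRDDs.contains(rdd.id))  // false again after unpersisting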
We could dig deeper into this, but it would become impractical to write and read. I think this level of abstraction is enough to understand that calling rdd.persist(StorageLevel) actually does something to keep your RDDs from being garbage collected too soon!
Hope this helps :)
Upvotes: 1