Suppose I have a function that generates a (Py)Spark data frame and caches it in memory as its last operation:
def gen_func(inputs):
    df = ...  # do stuff
    df.cache()
    df.count()
    return df
Per my understanding, Spark's caching works as follows: when cache/persist plus an action (count()) is called on a data frame, it is computed from its DAG and cached in memory, attached to the object that refers to it.
My question is: suppose I use gen_func to generate a data frame, but then overwrite the original data frame reference (perhaps with a filter or a withColumn):
df=gen_func(inputs)
df=df.filter("some_col = some_val")
In Spark, RDDs/DataFrames are immutable, so the reassigned df after the filter and the df before the filter refer to two entirely different objects. In this case, the reference to the original df that was cached/counted has been overwritten. Does that mean that the cached data frame is no longer available and will be garbage collected? Does that mean that the new post-filter df will compute everything from scratch, despite being derived from a previously cached data frame?
I am asking because I was recently fixing some out-of-memory issues in my code, and it seems to me that caching might be the problem. However, I do not yet fully understand what the safe ways to use cache are, or how one might accidentally invalidate one's cached data. What is missing in my understanding? Am I deviating from best practice in doing the above?
I wanted to make a couple of points that will hopefully clarify Spark's caching behavior.
When you have
df = ...  # do stuff
df.cache()
df.count()
...and then somewhere else in your application
another_df = ...  # do the same stuff
another_df.some_action()  # any action
..., you would expect another_df to reuse the cached df. After all, reusing the result of a prior computation is the whole point of caching. Realizing that, the Spark developers decided to use analyzed logical plans as the "key" that identifies cached dataframes, rather than relying on mere references held on the application side.
In Spark, CacheManager is the component that keeps track of cached computations, in the indexed sequence cachedData:
/**
* Maintains the list of cached plans as an immutable sequence. Any updates to the list
* should be protected in a "this.synchronized" block which includes the reading of the
* existing value and the update of the cachedData var.
*/
@transient @volatile
private var cachedData = IndexedSeq[CachedData]()
During query planning (in the Cache Manager phase), this structure is scanned for every subtree of the plan being analyzed, to see whether any of them has already been cached. If a match is found, Spark substitutes that subtree with the corresponding InMemoryRelation from cachedData.
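The substitution step can be illustrated with a toy model in plain Python. This is not Spark's code: plans are nested tuples of the form (operator, argument, child...), and tuple equality stands in for Spark's comparison of canonicalized plans. The names use_cached_data and cached_block_0 are made up for the illustration.

```python
# Toy model only: a "plan" is a nested tuple (operator, arg..., child...),
# and plain tuple equality stands in for Spark's canonicalized-plan match.

def use_cached_data(plan, cached):
    """Walk the plan top-down and swap any cached subtree for an InMemoryRelation."""
    if plan in cached:
        return ("InMemoryRelation", cached[plan])
    op, *rest = plan
    # Recurse into child plans (tuples); keep scalar arguments (strings) as-is.
    return (op, *(use_cached_data(x, cached) if isinstance(x, tuple) else x for x in rest))


base = ("Project", "col1", ("Scan", "source"))
cached = {base: "cached_block_0"}  # what df.cache() would have registered

query = ("Filter", "col1 != 2", base)  # e.g. df.filter('col1 != 2')
print(use_cached_data(query, cached))
# ('Filter', 'col1 != 2', ('InMemoryRelation', 'cached_block_0'))
```

Because the lookup is by plan, the filtered query picks up the cached data whether or not any variable still points at the original DataFrame.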
The cache() function (a synonym for persist() with no arguments) stores the dataframe with storage level MEMORY_AND_DISK by calling cacheQuery(...) in CacheManager:
/**
* Caches the data produced by the logical representation of the given [[Dataset]].
* Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
* recomputing the in-memory columnar representation of the underlying table is expensive.
*/
def cacheQuery(...
Note that this differs from RDD caching, which uses the MEMORY_ONLY level. Once cached, dataframes remain cached, either in memory or on local executor disk, until they are explicitly unpersist()-ed or the CacheManager's clearCache() is called. When executor storage memory fills up, cached blocks are evicted in LRU (least recently used) order and, with MEMORY_AND_DISK, written to disk rather than simply "dropped".
Good question, by the way...
I've done a few experiments, shown below. Apparently, a dataframe, once cached, remains cached (as shown by getPersistentRDDs and by InMemoryRelation/InMemoryTableScan in the query plan), even if all Python references to it are overwritten or deleted altogether with del, and even when garbage collection is explicitly triggered.
Experiment 1:
# Run in the pyspark shell, where `spark` and `sc` are predefined.
def func():
    data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
    data.cache()
    data.count()
    return data

sc._jsc.getPersistentRDDs()
df = func()
sc._jsc.getPersistentRDDs()
df2 = df.filter('col1 != 2')
del df
import gc
gc.collect()
sc._jvm.System.gc()
sc._jsc.getPersistentRDDs()
df2.select('*').explain()
del df2
gc.collect()
sc._jvm.System.gc()
sc._jsc.getPersistentRDDs()
Results:
>>> def func():
...     data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
...     data.cache()
...     data.count()
...     return data
...
>>> sc._jsc.getPersistentRDDs()
{}
>>> df = func()
>>> sc._jsc.getPersistentRDDs()
{71: JavaObject id=o234}
>>> df2 = df.filter('col1 != 2')
>>> del df
>>> import gc
>>> gc.collect()
93
>>> sc._jvm.System.gc()
>>> sc._jsc.getPersistentRDDs()
{71: JavaObject id=o240}
>>> df2.select('*').explain()
== Physical Plan ==
*(1) Filter (isnotnull(col1#174L) AND NOT (col1#174L = 2))
+- *(1) ColumnarToRow
   +- InMemoryTableScan [col1#174L], [isnotnull(col1#174L), NOT (col1#174L = 2)]
         +- InMemoryRelation [col1#174L], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Project [_1#172L AS col1#174L]
                  +- *(1) Scan ExistingRDD[_1#172L]
>>> del df2
>>> gc.collect()
85
>>> sc._jvm.System.gc()
>>> sc._jsc.getPersistentRDDs()
{71: JavaObject id=o250}
Experiment 2:
def func():
    data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
    data.cache()
    data.count()
    return data

sc._jsc.getPersistentRDDs()
df = func()
sc._jsc.getPersistentRDDs()
df = df.filter('col1 != 2')
import gc
gc.collect()
sc._jvm.System.gc()
sc._jsc.getPersistentRDDs()
df.select('*').explain()
del df
gc.collect()
sc._jvm.System.gc()
sc._jsc.getPersistentRDDs()
Results:
>>> def func():
...     data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
...     data.cache()
...     data.count()
...     return data
...
>>> sc._jsc.getPersistentRDDs()
{}
>>> df = func()
>>> sc._jsc.getPersistentRDDs()
{86: JavaObject id=o317}
>>> df = df.filter('col1 != 2')
>>> import gc
>>> gc.collect()
244
>>> sc._jvm.System.gc()
>>> sc._jsc.getPersistentRDDs()
{86: JavaObject id=o323}
>>> df.select('*').explain()
== Physical Plan ==
*(1) Filter (isnotnull(col1#220L) AND NOT (col1#220L = 2))
+- *(1) ColumnarToRow
   +- InMemoryTableScan [col1#220L], [isnotnull(col1#220L), NOT (col1#220L = 2)]
         +- InMemoryRelation [col1#220L], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Project [_1#218L AS col1#220L]
                  +- *(1) Scan ExistingRDD[_1#218L]
>>> del df
>>> gc.collect()
85
>>> sc._jvm.System.gc()
>>> sc._jsc.getPersistentRDDs()
{86: JavaObject id=o333}
Experiment 3 (control experiment, to show that unpersist works):
def func():
    data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
    data.cache()
    data.count()
    return data

sc._jsc.getPersistentRDDs()
df = func()
sc._jsc.getPersistentRDDs()
df2 = df.filter('col1 != 2')
df2.select('*').explain()
df.unpersist()
df2.select('*').explain()
Results:
>>> def func():
...     data = spark.createDataFrame([[1],[2],[3]]).toDF('col1')
...     data.cache()
...     data.count()
...     return data
...
>>> sc._jsc.getPersistentRDDs()
{}
>>> df = func()
>>> sc._jsc.getPersistentRDDs()
{116: JavaObject id=o398}
>>> df2 = df.filter('col1 != 2')
>>> df2.select('*').explain()
== Physical Plan ==
*(1) Filter (isnotnull(col1#312L) AND NOT (col1#312L = 2))
+- *(1) ColumnarToRow
   +- InMemoryTableScan [col1#312L], [isnotnull(col1#312L), NOT (col1#312L = 2)]
         +- InMemoryRelation [col1#312L], StorageLevel(disk, memory, deserialized, 1 replicas)
               +- *(1) Project [_1#310L AS col1#312L]
                  +- *(1) Scan ExistingRDD[_1#310L]
>>> df.unpersist()
DataFrame[col1: bigint]
>>> sc._jsc.getPersistentRDDs()
{}
>>> df2.select('*').explain()
== Physical Plan ==
*(1) Project [_1#310L AS col1#312L]
+- *(1) Filter (isnotnull(_1#310L) AND NOT (_1#310L = 2))
   +- *(1) Scan ExistingRDD[_1#310L]
To answer the OP's question:
Does that mean that the cached data frame is no longer available and will be garbage collected? Does that mean that the new post-filter df will compute everything from scratch, despite being generated from a previously cached data frame?
The experiments suggest the answer to both is no. The dataframe remains cached and is not garbage collected, and, according to the query plan, the new dataframe is computed from the cached (though no longer referenceable) dataframe.
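Some helpful functions related to cache usage (if you don't want to go through the Storage tab of the Spark UI) are sc._jsc.getPersistentRDDs(), which lists the currently persisted RDDs (including those backing cached dataframes), and spark.catalog.clearCache(), which removes all cached dataframes/tables from the cache (RDDs persisted directly with RDD.cache() are not affected).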
Am I deviating from best practice in doing the above?
I'm no authority on this, but as one of the comments suggested, avoid reassigning df, because dataframes are immutable. Imagine you're coding in Scala and defined df as a val: writing df = df.filter(...) would not compile. Python can't enforce that, but I think the best practice is to avoid overwriting any dataframe variables, so that you can always call df.unpersist() once you no longer need the cached results.