bluenote10

Reputation: 26729

How to achieve incremental caching without data duplication in Dask?

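I'm trying to find the equivalent of Spark's unpersist in Dask. My need for an explicit unpersist arises when the calling context already holds a persisted dataframe and passes it to a function that iteratively transforms and re-persists it.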

A basic example would look like:

def iterative_algorithm(df, num_iterations):

    for iteration in range(num_iterations):

        # Transformation logic requiring e.g. map_partitions
        def mapper(df):
            # ...
            return df

        df = df.map_partitions(mapper)
        df = df.persist()
        # Now I would like to explicitly unpersist the old snapshot

    return df

In Spark, the problem could be solved by explicitly releasing the old snapshots. Apparently Dask has no explicit unpersist and instead handles this via reference counting of the underlying futures. This means that the example above would duplicate the data: the calling context still holds references to the old futures, while the sub-function holds references to the newly persisted, modified data. In my actual use case there are several nested levels of such transformation calls, so the data is even duplicated multiple times.
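The duplication can be illustrated with a toy model of reference counting, written in plain Python so that it runs without Dask. ToyCluster, Handle and the snapshot keys are hypothetical names, not the dask.distributed API. The sketch assumes that persisted data stays in memory while at least one handle (standing in for a future) refers to it, and it relies on CPython freeing objects as soon as their last reference disappears.

```python
class ToyCluster:
    """Hypothetical stand-in for a distributed scheduler (not the Dask API).

    Data for a key stays in memory while at least one Handle refers to it.
    """

    def __init__(self):
        self.memory = {}    # key -> data kept in "cluster memory"
        self.refcount = {}  # key -> number of live Handle objects

    def persist(self, key, data):
        self.memory[key] = data
        self.refcount[key] = 0
        return Handle(self, key)

    def release(self, key):
        self.refcount[key] -= 1
        if self.refcount[key] == 0:
            # Last reference gone: the data can be dropped.
            del self.memory[key]
            del self.refcount[key]


class Handle:
    """Plays the role of a future: one reference to a key while alive."""

    def __init__(self, cluster, key):
        self.cluster = cluster
        self.key = key
        cluster.refcount[key] += 1

    def __del__(self):
        # CPython calls this as soon as the last Python reference disappears.
        self.cluster.release(self.key)


def iterative_algorithm(cluster, df, num_iterations):
    for i in range(num_iterations):
        data = cluster.memory[df.key]
        # Rebinding the local name `df` drops only the function's reference;
        # the caller's variable still keeps the previous snapshot alive.
        df = cluster.persist(f"snapshot-{i}", [x + 1 for x in data])
    return df


cluster = ToyCluster()
df = cluster.persist("original", [1, 2, 3])
result = iterative_algorithm(cluster, df, num_iterations=3)
print(sorted(cluster.memory))  # ['original', 'snapshot-2']
del df
print(sorted(cluster.memory))  # ['snapshot-2']
```

The intermediate snapshots are freed because only the function's local name referred to them, but the original stays in memory until the caller drops its own reference.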

Is there a way to solve iterative caching without any additional copies?

Upvotes: 3

Views: 892

Answers (2)

MRocklin

Reputation: 57319

You could write a release function as follows:

from distributed.client import futures_of

def release(collection):
    for future in futures_of(collection):
        future.release()

This will only release the current instance. If you have multiple instances of these futures lying around, you might have to call it a few times or add a loop like the following:

while future.client.refcount[future.key] > 0:

But in general, calling this multiple times seems unwise, since you may have other copies floating around for a good reason.

Upvotes: 3

bluenote10

Reputation: 26729

I'll post some ideas on how to solve this, but I'm still looking for better alternatives.

Because of the reference counting, it is tricky to avoid the copies, but it is possible. The problem results from the caller holding a reference to the original df while the sub-function creates new instances via df = df.<method> calls. To solve the problem, we would have to make the reference to df itself mutable. Unfortunately, Python does not allow a function to rebind the variable that the caller passed in as an argument.

Solution 1: Naive mutable reference

The simplest way to work around that limitation is to wrap the df in a list or dict. In this case, the sub-function can modify the external reference, e.g.:

df_list[0] = df_list[0].map_partitions(mapper)
df_list[0] = df_list[0].persist()

However, this is syntactically awkward and one has to be very careful, because simplifying the syntax via df = df_list[0] again creates new references to the underlying futures, which can cause data duplication.
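The list-based approach and its aliasing pitfall can be sketched in plain Python. Snapshot and step are hypothetical names: a Snapshot stands in for a persisted dataframe, and weakref is used only to observe whether the old snapshot has been freed (this relies on CPython's immediate reference counting).

```python
import weakref


class Snapshot:
    """Hypothetical stand-in for an immutable, persisted dataframe."""

    def __init__(self, data):
        self.data = tuple(data)

    def transform(self, fn):
        # Like df.map_partitions(...).persist(): returns a *new* object.
        return Snapshot(fn(x) for x in self.data)


def step(df_list):
    # Mutate the shared container instead of rebinding a local name, so the
    # caller sees the new snapshot and the old one loses its last reference.
    df_list[0] = df_list[0].transform(lambda x: x + 1)


df_list = [Snapshot([1, 2, 3])]
old = weakref.ref(df_list[0])
step(df_list)
print(df_list[0].data)  # (2, 3, 4)
print(old() is None)    # True: the old snapshot was freed

# Pitfall: a convenience alias is one more strong reference.
df = df_list[0]
old = weakref.ref(df)
step(df_list)
print(old() is None)    # False: `df` keeps the previous snapshot alive
```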

Solution 2: Wrapper-based mutable reference

Improving on that, one could write a small wrapper class that holds a reference to a dataframe. When this wrapper is passed around, the sub-functions can mutate the reference. To address the syntax issue, one might consider whether the wrapper should delegate functionality to the dataframe automatically or inherit from it. Overall, this solution also doesn't feel right.
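One possible shape for such a wrapper, as a sketch only: DataFrameRef, its update/apply methods, and FakeFrame (a minimal immutable stand-in so the example runs without Dask) are all hypothetical names. It delegates attribute access via __getattr__ rather than inheriting, so the wrapper holds exactly one reference to the current snapshot.

```python
class DataFrameRef:
    """Hypothetical mutable holder for an immutable dataframe-like object."""

    def __init__(self, df):
        self._df = df

    def update(self, new_df):
        # Replace the held snapshot; the holder's old reference is dropped.
        self._df = new_df

    def apply(self, method_name, *args, **kwargs):
        # In-place style call, e.g. ref.apply("map_partitions", mapper)
        self._df = getattr(self._df, method_name)(*args, **kwargs)

    def __getattr__(self, name):
        # Only called when normal lookup fails: delegate to the wrapped frame.
        # Guard `_df` itself to avoid infinite recursion before __init__ runs.
        if name == "_df":
            raise AttributeError(name)
        return getattr(self._df, name)


class FakeFrame:
    """Minimal immutable stand-in for a Dask DataFrame (hypothetical)."""

    def __init__(self, values, persisted=False):
        self.values = tuple(values)
        self.persisted = persisted

    def map_partitions(self, fn):
        return FakeFrame(fn(self.values))

    def persist(self):
        return FakeFrame(self.values, persisted=True)


def iterative_algorithm(ref, num_iterations):
    for _ in range(num_iterations):
        ref.apply("map_partitions", lambda vals: [v * 2 for v in vals])
        ref.apply("persist")


ref = DataFrameRef(FakeFrame([1, 2, 3]))
iterative_algorithm(ref, num_iterations=2)
print(ref.values, ref.persisted)  # (4, 8, 12) True
```

Anything obtained through delegation, such as a column, is again a separate reference to the snapshot, so the aliasing caveat of Solution 1 still applies.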

Solution 3: Explicit mutation

To avoid the syntax issues of the other solutions, I currently prefer the following variant, which effectively simulates mutable versions of map_partitions and persist by modifying the original df instance in place.

def modify_inplace(old_df, new_df):
    # Currently requires accessing private fields of a DataFrame, but
    # maybe this could be officially supported by Dask.
    old_df.dask = new_df.dask
    old_df._meta = new_df._meta
    old_df._name = new_df._name
    old_df.divisions = new_df.divisions


def iterative_algorithm(df, num_iterations):

    for iteration in range(num_iterations):

        def mapper(df):
            # Actual transform logic...
            return df

        # Simulate mutable/in-place map_partitions
        new_df = df.map_partitions(mapper)
        modify_inplace(df, new_df)

        # Simulate mutable/in-place persist
        new_df = df.persist()
        modify_inplace(df, new_df)

    # Technically no need to return, because all operations were in-place
    return df

This works reasonably well for me, but requires following these rules carefully:

  • Replace all immutable calls like df = df.<method> with the pattern above.
  • Pay attention to creating references to df. For instance, using a variable like some_col = df["some_col"] for syntactic convenience requires calling del some_col before calling persist. Otherwise the reference stored within some_col will again cause data duplication.
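To shorten the two-line pattern required by the first rule, the call and the copy-back can be folded into a single helper. This is a sketch: call_inplace is a hypothetical name, and modify_inplace is repeated from above so the snippet stands alone. It relies on the same private DataFrame attributes (dask, _meta, _name, divisions), so it only works where the DataFrame class exposes them.

```python
def modify_inplace(old_df, new_df):
    # Same private-attribute copy as in the answer above; repeated here so
    # this snippet runs on its own.
    old_df.dask = new_df.dask
    old_df._meta = new_df._meta
    old_df._name = new_df._name
    old_df.divisions = new_df.divisions


def call_inplace(df, method_name, *args, **kwargs):
    """Hypothetical helper: evaluate df.<method_name>(*args, **kwargs) and
    copy the result's state back into df. The temporary new_df is local, so
    no extra reference to the new graph survives the call."""
    new_df = getattr(df, method_name)(*args, **kwargs)
    modify_inplace(df, new_df)
    return df
```

Inside the loop, the two pairs of lines would then become call_inplace(df, "map_partitions", mapper) and call_inplace(df, "persist"). The second rule still applies: the helper cannot see aliases such as some_col that the caller keeps around.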

Upvotes: 3
