Val

Reputation: 7023

dask.distributed: wait for all tasks to finish before shutdown (without futures)

TL;DR:

I'm using fire_and_forget to execute tasks on a dask.distributed cluster, so I don't maintain a future for each task. How can I wait until they are all done before the cluster gets shut down?

Details:

I have a workflow that creates an xarray dataset, which is persisted on the cluster. Once the computations are done, I want to save each time slice to its own file and move on to the next dataset.

Until now, I've been using a delayed function to build a list of delayed tasks, which I then passed to client.compute. That way I was sure everything was done before I moved on to the next dataset. The downside is that everything is blocked until the last file has been written.
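For comparison, the previous blocking approach could look roughly like the following minimal sketch. It assumes plain Python lists stand in for the xarray time slices and that files go to a temporary directory; `dump_delayed` is adapted from the example further down (no sleep, and it returns the filename), and `write_all_blocking` is a hypothetical helper name, not part of dask.

```python
import os
import pickle
import tempfile

from dask import delayed
from dask.distributed import Client, wait


@delayed
def dump_delayed(x, fn):
    # Pickle one object to disk (stand-in for writing one time slice)
    with open(fn, "wb") as f:
        pickle.dump(x, f)
    return fn


def write_all_blocking(client, slices, target, it):
    # Hypothetical helper: one delayed write task per slice of dataset `it`
    tasks = [
        dump_delayed(x, os.path.join(target, f"temp{ii}_{it}.pkl"))
        for ii, x in enumerate(slices)
    ]
    futures = client.compute(tasks)  # a list of Futures, one per task
    wait(futures)  # blocks here until every file of this dataset is written
    return [f.result() for f in futures]


if __name__ == "__main__":
    target = tempfile.mkdtemp()
    with Client(processes=False) as client:
        for it in range(1, 4):
            slices = [[it * ii] for ii in range(10)]  # fake "time slices"
            write_all_blocking(client, slices, target, it)
    print(len(os.listdir(target)))  # 30
```

Each iteration only starts after all of its files are written, which is exactly the blocking behaviour described above.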

Now I'm looking into fire_and_forget to be able to start the computations on the next dataset while the files of the previous one are still being written.

I'm planning to wait for each dataset to be completed before I start its fire_and_forget tasks, so they should have plenty of time to complete. The only issue I've encountered is that, when processing the last dataset, there is no more waiting, so the cluster gets shut down right after the last fire_and_forget call even though the tasks are still running.

So is there any way to tell the client it needs to block until all is completed?

Or am I misunderstanding the purpose of fire_and_forget, and should I stay with my previous approach?

Here's example code that simulates the workflow: it does 10 iterations (simulating the different datasets) and in each one writes the first 10 time slices to pickle files. So in the end I expect 100 pickle files on disk, which is not what happens.

import pickle
import random
from time import sleep

from dask import delayed
from dask.distributed import LocalCluster, Client, wait, fire_and_forget
import xarray as xr

@delayed
def dump_delayed(x, fn):
    with open(fn, "wb") as f:
        random.seed(42)
        sleep(random.randint(1,2))
        pickle.dump(x, f)

TARGET = "/home/jovyan/"

def main():

    cluster = LocalCluster(n_workers=2, ip="0.0.0.0")
    client = Client(cluster)

    ds = xr.tutorial.open_dataset("rasm")

    for it in range(1, 11):  # 10 iterations, as described above
        print("Iteration %s" % it)

        # simulating the processing and persisting
        ds2 = (ds*it).chunk({"time": 1}).persist()
        _ = wait(ds2)
    
        for ii in range(10):
            
            fn = TARGET + f"temp{ii}_{it}.pkl"
            xx = ds2.isel(time=ii)
            
            f = client.persist(dump_delayed(xx, fn))
            fire_and_forget(f)
            

if __name__ == "__main__":
    main()

Upvotes: 3

Views: 1950

Answers (1)

SultanOrazbayev

Reputation: 16551

Not sure if this qualifies as a solution, but fire_and_forget is meant for the specific use case where you do not want to track the status of a task at all. If you care about when the tasks finish (for example, to block until they are done before shutting down), it's better to keep regular futures.
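Following that suggestion, here is a minimal sketch of the question's loop using regular futures. `client.compute` returns futures immediately, so the loop can move on to the next dataset while the previous files are still being written. The futures are collected in a list, and a single `client.gather` at the end blocks until every write is done (and re-raises the first write error, if any) before the client and its cluster shut down. As before, plain Python lists stand in for the persisted xarray datasets, and `process_and_write` is a hypothetical name.

```python
import os
import pickle
import tempfile

from dask import delayed
from dask.distributed import Client


@delayed
def dump_delayed(x, fn):
    # Pickle one object to disk (stand-in for writing one time slice)
    with open(fn, "wb") as f:
        pickle.dump(x, f)
    return fn


def process_and_write(client, target, n_datasets=10, n_slices=10):
    # Hypothetical driver loop; plain lists stand in for the xarray datasets
    pending = []  # futures for every write task submitted so far
    for it in range(1, n_datasets + 1):
        # Simulated "processing" step for dataset number `it`
        data = [[it * ii] for ii in range(n_slices)]
        tasks = [
            dump_delayed(x, os.path.join(target, f"temp{ii}_{it}.pkl"))
            for ii, x in enumerate(data)
        ]
        # compute() returns futures right away, so the loop continues
        # with the next dataset while these files are being written
        pending.extend(client.compute(tasks))
    # Block once, at the end, until all writes have finished;
    # gather() also re-raises the first exception from a failed write
    return client.gather(pending)


if __name__ == "__main__":
    target = tempfile.mkdtemp()
    with Client(processes=False) as client:
        written = process_and_write(client, target)
    # The client (and its local cluster) closes only after all writes finish
    print(len(os.listdir(target)))  # 100
```

If keeping many outstanding writes around is a concern, a variation would be to wait on the previous dataset's futures at the start of each iteration instead, so that at most one dataset's writes are in flight at a time.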

Upvotes: 1
