tt293

Reputation: 510

Using dask delayed with functions returning lists

I am trying to use dask.delayed to build up a task graph. This mostly works nicely, but I regularly run into situations like the one below: I have a number of delayed objects, each with a method that returns a list whose length cannot easily be computed from the information available at this point:

from dask import delayed

items = get_collection()  # known length; each item is a Delayed object

def do_work(item):
    # get_list_of_things returns a list of "unknown" length
    return [x.DoStuff() for x in item.get_list_of_things()]

results = [delayed(do_work(x)) for x in items]

This gives a

TypeError: Delayed objects of unspecified length are not iterable

Is there any way in dask to work around this, preferably without calling .compute() on the intermediate results, since that would destroy most of the benefit of having a task graph? The graph cannot be fully resolved until some of its steps have run, but the only thing that varies is the width of one parallel section; it doesn't change the structure or the depth of the graph.

Upvotes: 3

Views: 2402

Answers (1)

MRocklin

Reputation: 57271

Unfortunately, if you want to call an individual function on each element of your list, then that is part of the structure of your graph and must be known at graph construction time if you want to use dask.delayed.

In general I see two options:

  1. Don't make an individual task for every element of your list; instead, make one task for the first 10%, one for the second 10%, and so on. This is the same approach taken by dask.bag, which also handles parallelism over an unknown number of elements (and which might be worth considering).

    http://dask.pydata.org/en/latest/bag.html

  2. Switch to the real-time concurrent.futures interface and wait for the result of your list before submitting more work:

    import operator
    from dask.distributed import Client

    client = Client()
    list_future = client.submit(do_work, *args)   # args: whatever inputs do_work needs
    len_future = client.submit(len, list_future)  # runs on a worker; the list stays remote

    n = len_future.result()  # block until the length is computed

    futures = [client.submit(operator.getitem, list_future, i) for i in range(n)]

    # ... do more stuff with futures
    

    http://dask.pydata.org/en/latest/futures.html
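A minimal sketch of option 1 with dask.delayed, assuming hypothetical stand-ins `get_list_of_things`, `do_stuff`, and `process_chunk` (none of these are dask APIs) and an arbitrary choice of 10 chunks per list. The number of chunk tasks is fixed when the graph is built; how many elements land in each chunk is only decided at run time, so no intermediate `.compute()` is needed.

```python
import dask
from dask import delayed

N_CHUNKS = 10  # fixed when the graph is built (arbitrary choice)

def get_list_of_things(item):
    # Hypothetical stand-in: the list length is only known at run time
    return list(range(item * 3))

def do_stuff(x):
    # Hypothetical per-element work
    return x * 2

def process_chunk(things, i, n_chunks):
    # Process the i-th of n_chunks contiguous slices; the slice bounds
    # are computed here, at run time, once len(things) is known
    n = len(things)
    start, stop = n * i // n_chunks, n * (i + 1) // n_chunks
    return [do_stuff(x) for x in things[start:stop]]

items = [1, 2, 3]
tasks = []
for item in items:
    things = delayed(get_list_of_things)(item)  # one Delayed list per item
    for i in range(N_CHUNKS):
        # The graph width per item is N_CHUNKS regardless of the list length
        tasks.append(delayed(process_chunk)(things, i, N_CHUNKS))

chunks = dask.compute(*tasks)  # a single compute over the whole graph
results = [y for chunk in chunks for y in chunk]
```

Contiguous slices keep elements in their original order, so concatenating the chunk results reproduces the order of the underlying lists. Some chunks are empty when a list is shorter than the chunk count; that costs a trivial task but keeps the shape of the graph independent of the data.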

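dask's futures interface mirrors the standard library's concurrent.futures, so the submit, wait, then fan-out pattern of option 2 can be sketched with a `ThreadPoolExecutor` for illustration. This is a swapped-in stdlib analogue rather than dask code, and `do_work`/`do_stuff` here are hypothetical. One difference: the stdlib executor does not resolve futures passed as arguments the way `Client.submit` does, so the list itself is fetched with `.result()` before fanning out.

```python
from concurrent.futures import ThreadPoolExecutor

def do_work(n):
    # Hypothetical stand-in: returns a list whose length is only known once it has run
    return [i * i for i in range(n)]

def do_stuff(x):
    # Hypothetical per-element work
    return x + 1

with ThreadPoolExecutor(max_workers=4) as executor:
    list_future = executor.submit(do_work, 5)

    # Block until the list exists; only now is the width of the next stage known
    things = list_future.result()

    # Fan out one task per element
    futures = [executor.submit(do_stuff, x) for x in things]
    results = [f.result() for f in futures]

print(results)  # [1, 2, 5, 10, 17]
```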
Upvotes: 3
