tlambert

Reputation: 77

embedding pre/post-compute operations in a dask.Array task graph

I'm interested in creating a dask.array.Array that opens and closes a resource before/after compute(). But, I'd like to make no assumptions about how the end-user is going to call compute and I'd like to avoid creating a custom dask Array subclass or proxy object, so I'm trying to embed the operations in the __dask_graph__ underlying the array.

(aside: Please ignore for the moment caveats about using stateful objects in dask, I'm aware of the risks, this question is just about task graph manipulation).

Consider the following class, which simulates a file reader that must be in an open state in order to read a chunk; otherwise it segfaults.

import dask.array as da
import numpy as np

class FileReader:
    _open = True

    def open(self):
        self._open = True

    def close(self):
        self._open = False

    def to_dask(self) -> da.Array:
        return da.map_blocks(
            self._dask_block,
            chunks=((1,) * 4, 4, 4),
            dtype=float,
        )

    def _dask_block(self):
        if not self._open:
            raise RuntimeError("Segfault!")
        return np.random.rand(1, 4, 4)

If the file stays open, everything is fine, but if one closes the file, the dask array returned from to_dask will fail:

>>> t = FileReader()
>>> darr = t.to_dask()
>>> t.close()
>>> darr.compute()  # RuntimeError: Segfault!

The current task graph looks like:

>>> list(darr.dask)
[
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 0, 0, 0),
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 1, 0, 0),
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 2, 0, 0),
    ('_dask_block-aa4daac0835bafe001693f9ac085683a', 3, 0, 0)
]

Essentially, I'd like to add a new task to the beginning of the graph, which the _dask_block layer depends on, and a task to the end, which depends on _dask_block.

I tried directly manipulating the HighLevelGraph output of da.map_blocks to add these tasks manually, but found that they were getting pruned during compute optimizations, because darr.__dask_keys__() doesn't contain my keys (and, again, I'd like to avoid subclassing or requiring the end-user to call compute with special optimization flags).
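(To illustrate why the extra key disappears, here is a minimal sketch of what the optimization step does, using dask.optimization.cull on a toy graph; the key names here are made up:)

from dask.optimization import cull

# a toy graph: one "open" task, one block that uses it, and a dangling "close" task
dsk = {
    "open": (lambda: "opened",),
    ("block", 0): (lambda _: 1, "open"),
    "close": (lambda _: None, "open"),
}

# compute() culls the graph down to what __dask_keys__() asks for, so any key
# that the outputs don't (transitively) depend on is dropped
culled, _ = cull(dsk, keys=[("block", 0)])
print(list(culled))  # only ('block', 0) and 'open' survive; 'close' is pruned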

One solution would be to make sure that the _dask_block function passed to map_blocks always opens and closes the underlying resource... but let's assume that the open/close process is relatively slow, effectively destroying performance for single-machine parallelism. So we only want a single open at the beginning, and close at the end.
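(For concreteness, the slow per-chunk variant I'm trying to avoid would look something like this; the method name is just illustrative:)

    # the naive alternative: pay the (slow) open/close cost for every single chunk
    def _dask_block_slow(self):
        self.open()
        try:
            return np.random.rand(1, 4, 4)
        finally:
            self.close()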

I can "cheat" a little to include a task that opens my file, by including a new key in my call to map_blocks as follows:

    ...
    
    # new method that will run at beginning of compute()
    def _pre_compute(self):
        was_open = self._open
        if not was_open:
            self.open()
        return was_open

    def to_dask(self) -> da.Array:
        # new task key
        pre_task = 'pre_compute-' + tokenize(self._pre_compute)
        arr = da.map_blocks(
            self._dask_block,
            pre_task,  # add key here so all chunks depend on it
            chunks=((1,) * 4, 4, 4),
            dtype=float,
        )
        # add key to HighLevelGraph
        arr.dask.layers[pre_task] = {pre_task: (self._pre_compute,)}
        return da.Array(arr.dask, arr.name, arr.chunks, arr.dtype)

    # add "mock" argument to map_blocks function
    def _dask_block(self, _):
        if not self._open:
            raise RuntimeError("Segfault!")
        return np.random.rand(1, 4, 4)

So far so good: no more RuntimeError... but now I have a leaked file handle, since nothing closes the file at the end.

What I'd like, then, is a task at the end of the graph that runs after all the _dask_block tasks, receives the output of pre_task (i.e. whether the file had to be opened for this compute), and closes the file if it had to be opened.
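Something along these lines (the method name _post_compute is just illustrative):

    # hypothetical closing counterpart to _pre_compute
    def _post_compute(self, was_open):
        # `was_open` is the value returned by the pre_compute task
        if not was_open:
            self.close()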

And here's where I'm stuck, since that post-compute key will get pruned by the optimizer...

Is there any way to do this without creating a custom Array subclass that overrides methods like __dask_postcompute__ or __dask_keys__, or requiring the end-user to call compute without optimization?

Upvotes: 1

Views: 298

Answers (1)

Ian Rose

Reputation: 91

This is a really interesting question. I think you are on the right track with editing the task graph to include tasks for opening and closing your shared resource. But manual graph manipulation is fiddly and difficult to get right.

I think the easiest way to accomplish what you want is to use some of the relatively recently added utilities in dask.graph_manipulation. In particular, I think we want bind, which can be used to add implicit dependencies to a Dask collection, and wait_on, which can be used to ensure that dependents of a collection wait on another, unrelated collection.
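As a minimal illustration of their semantics (a toy example, separate from your code):

import dask
from dask.graph_manipulation import bind, wait_on

@dask.delayed
def say(msg):
    print(msg)
    return msg

first = say("first")
second = say("second")

# `second` now runs after `first`, even though it never uses first's result
second = bind(second, first)

# and anything that later depends on `first` will also wait for `second`
first, second = wait_on(first, second)

dask.compute(first, second)  # prints "first", then "second"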

I took a pass at modifying your example with these utilities to create a version of to_dask() which is self-opening and self-closing:

import dask
import dask.array as da
import numpy as np
from dask.graph_manipulation import bind, wait_on


class FileReader:
    _open = False

    def open(self):
        self._open = True

    def close(self):
        self._open = False

    def to_dask(self) -> da.Array:
        # Delayed version of self.open
        @dask.delayed
        def open_resource():
            self.open()
        # Delayed version of self.close
        @dask.delayed
        def close_resource():
            self.close()
            
        opener = open_resource()
        arr = da.map_blocks(
            self._dask_block,
            chunks=((1,) * 4, 4, 4),
            dtype=float,
        )
        # Make sure the array is dependent on `opener`
        arr = bind(arr, opener)

        closer = close_resource()
        # Make sure the closer is dependent on the array being done
        closer = bind(closer, arr)
        # Make sure dependents of arr happen after `closer` is done
        arr, closer = wait_on(arr, closer)
        return arr

    def _dask_block(self):
        if not self._open:
            raise RuntimeError("Segfault!")
        return np.random.rand(1, 4, 4)
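
With that, the calling pattern from your question should work without the user managing the resource at all (a quick check, assuming the default threaded scheduler so that the state change happens in-process):

>>> t = FileReader()
>>> darr = t.to_dask()
>>> darr.compute().shape  # opens, reads all chunks, then closes
(4, 4, 4)
>>> t._open               # the reader is closed again once compute() finishes
False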

I found it interesting to look at the task graph before and after the manipulations.

Before, it is a relatively straightforward chunked array:

[task graph visualization: chunked array with no checkpoints or extra dependencies]

But after the manipulations, you can see that the array blocks depend on open_resource, and these then flow into close_resource, which then flows into letting the array chunks into the wider world:

[task graph visualization: the self-opening and self-closing array]
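
(These images can be reproduced with the collection's visualize method, provided graphviz is installed:)

t = FileReader()
darr = t.to_dask()
darr.visualize("self_opening_closing_graph.png")  # renders the task graph with graphviz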

Upvotes: 4
