Reputation: 53
I am working on building a computation graph with Dask. Some of the intermediate values will be used multiple times, but I would like those calculations to only run once. I must be making a trivial mistake, because that's not what happens. Here is a minimal example:
In [1]: import dask
   ...: dask.__version__
Out[1]: '1.0.0'
In [2]: class SumGenerator(object):
   ...:     def __init__(self):
   ...:         self.sources = []
   ...:     def register(self, source):
   ...:         self.sources += [source]
   ...:     def generate(self):
   ...:         return dask.delayed(sum)([s() for s in self.sources])
In [3]: sg = SumGenerator()
In [4]: @dask.delayed
   ...: def source1():
   ...:     return 1.
   ...: @dask.delayed
   ...: def source2():
   ...:     return 2.
   ...: @dask.delayed
   ...: def source3():
   ...:     return 3.
In [5]: sg.register(source1)
   ...: sg.register(source1)
   ...: sg.register(source2)
   ...: sg.register(source3)
In [6]: sg.generate().visualize()
Sadly, I am unable to post the resulting graph image, but I see two separate nodes for the function source1, which was registered twice, so the function is called twice. I would rather have it called once, with the result remembered and added twice in the sum. What is the correct way to do that?
Upvotes: 3
Views: 490
Reputation: 7123
You need to pass the pure=True argument to the dask.delayed decorator.
From the dask delayed docs:
delayed also accepts an optional keyword pure. If False, then subsequent calls will always produce a different Delayed.
If you know a function is pure (output only depends on the input, with no global state), then you can set pure=True.
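To make the quoted behaviour concrete, here is a toy model of why pure=True merges the duplicate nodes. It is only a sketch using the standard library, not Dask's actual implementation: make_key and build_graph are hypothetical helpers, and the assumption is simply that a pure call gets a key derived from the function and its arguments, while a non-pure call gets a fresh random key every time.

```python
import hashlib
import uuid


def make_key(func, args, pure):
    """Hypothetical helper: build a task key for the call func(*args)."""
    if pure:
        # Deterministic token: same function and arguments give the same key.
        payload = repr((func.__module__, func.__qualname__, args)).encode()
        token = hashlib.sha1(payload).hexdigest()
    else:
        # Random token: every call gets a brand-new key.
        token = uuid.uuid4().hex
    return f"{func.__name__}-{token}"


def source1():
    return 1.0


def build_graph(calls, pure):
    """Hypothetical helper: a graph is a dict, so equal keys collapse."""
    graph = {}
    for func, args in calls:
        graph[make_key(func, args, pure)] = (func, args)
    return graph


# source1 is "registered" twice, as in the question.
calls = [(source1, ()), (source1, ())]
impure_graph = build_graph(calls, pure=False)
pure_graph = build_graph(calls, pure=True)
print(len(impure_graph), len(pure_graph))  # 2 1
```

Because a task graph is keyed by name, two calls that produce the same key end up as a single node, which is the effect you are after.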
So, using that:
import dask

class SumGenerator(object):
    def __init__(self):
        self.sources = []

    def register(self, source):
        self.sources += [source]

    def generate(self):
        return dask.delayed(sum)([s() for s in self.sources])

@dask.delayed(pure=True)
def source1():
    return 1.

@dask.delayed(pure=True)
def source2():
    return 2.

@dask.delayed(pure=True)
def source3():
    return 3.

sg = SumGenerator()
sg.register(source1)
sg.register(source1)
sg.register(source2)
sg.register(source3)
sg.generate().visualize()
Output and Graph
Using print(dask.compute(sg.generate())) gives (7.0,), which is the same result as with your version, but without the extra node, as seen in the image.
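If you cannot look at the graph image, one way to check that the function really runs only once is to count the calls. The snippet below is a minimal sketch, assuming a Dask version that accepts the scheduler= keyword in compute; the calls list is only there for instrumentation (strictly speaking it makes the function impure, so do not do this in real code).

```python
import dask

calls = []  # instrumentation only: records each real execution


@dask.delayed(pure=True)
def source1():
    calls.append("source1")
    return 1.0


first = source1()
second = source1()

# With pure=True both calls should carry the same key ...
same_key = first.key == second.key

# ... so the graph should hold one source1 task whose result is used twice.
total = dask.delayed(sum)([first, second])
result = total.compute(scheduler="synchronous")

print(same_key, result, len(calls))
```

If the deduplication works as described, same_key is True and calls ends up with a single entry even though the value is added twice in the sum.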
Upvotes: 3