Reputation: 2129
I have data stored in an xarray DataArray with dimensions (nt, nb, nx, ny).
I wrote some code that calculates some quantities cell-wise (e.g., for each cell in nx and ny) between two different values of the 0-th dimension. My processing can be done independently in nt, but also in nx and ny. I don't know how to get this to run in parallel using dask.
The following example demonstrates what I want to do (my calculations are more complex than those used here). However, this runs sequentially, and I'd like to parallelise it, and make use of the chunk structure in the data.
import numpy as np
import xarray as xr
import xarray.tutorial
from numba import njit, float64, float32
from itertools import product

@njit('Tuple((float32[:, :],float32[:,:]))(float32[:, :, :], float32[:, :,:])')
def do_smthg(ar1, ar2):
    n1, n2, n3 = ar1.shape
    outa = np.zeros((n2, n3), dtype=np.float32)
    outb = np.zeros((n2, n3), dtype=np.float32)
    # Loop over the spatial cells (n2 x n3), reducing over the leading axis.
    for i in range(n2):
        for j in range(n3):
            outa[i,j] = np.sum(ar1[:, i,j] - ar2[:, i,j])
            outb[i,j] = np.sum(ar1[:, i,j] + ar2[:, i,j])
    return outa, outb

da = xarray.tutorial.load_dataset("era5-2mt-2019-03-uk.grib")
da = da.chunk("auto")
F = {}
for (t1,tt1), (t2, tt2) in product(da.t2m.groupby("time.day"),
                                   da.t2m.groupby("time.day")):
    # t1 and t2 are timesteps. Calculate for pairs t1 and all times after.
    if t2 > t1:
        F[(t1, t2)] = do_smthg(tt1.values, tt2.values)
One way to parallelise this would be to have a dask client available, and map things over, but this requires a lot of thinking and data shifting:
from distributed import LocalCluster, Client

cluster = LocalCluster()
client = Client(cluster)
F = {}
for (t1,tt1), (t2, tt2) in product(da.t2m.groupby("time.day"),
                                   da.t2m.groupby("time.day")):
    if t2 > t1:
        F[(t1, t2)] = client.submit(do_smthg, tt1.values, tt2.values)
F = {k:v.result() for k,v in F.items()}
This kind of works, but I am unsure there's any clever parallelisation going on. Besides, it needs to shift loads of data around. It looks like the kind of thing that some xarray/dask jiggery pokery should make very efficient. I would want to run this on a large dask cluster, where my datasets would be very large (but chunked).
Using map_blocks isn't clear either:
# template output dataset
out = xr.Dataset(
    data_vars={"outa":(["lat", "lon"], np.random.rand(33, 49)),
               "outb":(["lat", "lon"], np.random.rand(33, 49))})
out.coords["lat"] = da.coords["latitude"].values
out.coords["lon"] = da.coords["longitude"].values
out = out.chunk("auto")
F = {}
for (t1,tt1), (t2, tt2) in product(da.t2m.groupby("time.day"),
                                   da.t2m.groupby("time.day")):
    # t1 and t2 are timesteps. Calculate for pairs t1 and all times after.
    if t2 > t1:
        F[(t1, t2)] = tt1.drop("time").map_blocks(do_smthg, args=[tt2.drop("time")], template=out)
F[(1,5)].outb.values
This results in an error when running the numba code:
TypeError: No matching definition for argument type(s) pyobject, pyobject
If I remove the numba wrapper and just use the vanilla, slow Python function, this runs until the end and returns this message:
~/mambaforge/lib/python3.9/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
~/mambaforge/lib/python3.9/site-packages/xarray/core/parallel.py in _wrapper(func, args, kwargs, arg_is_array, expected)
286
287 # check all dims are present
--> 288 missing_dimensions = set(expected["shapes"]) - set(result.sizes)
289 if missing_dimensions:
290 raise ValueError(
AttributeError: 'numpy.ndarray' object has no attribute 'sizes'
So there's something strange going on here with passing different variables.
Upvotes: 0
Views: 359
Reputation: 11
I believe you need to make the output of do_smthg match the template.
I experienced the same AttributeError: 'numpy.ndarray' object has no attribute 'sizes' using xr.map_blocks when the output of the function was a numpy array. When I converted the output to be in the expected xarray template, the map_blocks code executed as expected.
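A minimal sketch of what that conversion could look like, under some assumptions: it uses small synthetic arrays instead of the ERA5 tutorial file, a plain NumPy stand-in for the numba kernel, and a hypothetical wrapper named do_smthg_xr. map_blocks hands the function xarray blocks rather than NumPy arrays, which is probably also why the numba version complains about pyobject arguments, so the wrapper unwraps them with .values, calls the kernel, and packs the two results into a Dataset shaped like the template.

```python
import numpy as np
import xarray as xr


def do_smthg(ar1, ar2):
    """NumPy stand-in for the question's kernel: reduce over the leading axis."""
    outa = (ar1 - ar2).sum(axis=0)
    outb = (ar1 + ar2).sum(axis=0)
    return outa, outb


def do_smthg_xr(block1, block2):
    """Hypothetical wrapper: take DataArray blocks, return a Dataset block."""
    # .values yields plain NumPy arrays, which is what the kernel expects.
    outa, outb = do_smthg(block1.values, block2.values)
    return xr.Dataset(
        {"outa": (("lat", "lon"), outa), "outb": (("lat", "lon"), outb)},
        coords={"lat": block1["lat"], "lon": block1["lon"]},
    )


# Synthetic stand-ins for two days of data (time, lat, lon), without a time
# coordinate, mirroring the drop("time") in the question.
rng = np.random.default_rng(0)
coords = {"lat": np.arange(4.0), "lon": np.arange(6.0)}
dims = ("time", "lat", "lon")
a = xr.DataArray(rng.random((5, 4, 6)), dims=dims, coords=coords)
b = xr.DataArray(rng.random((5, 4, 6)), dims=dims, coords=coords)

# Chunk only the spatial dimensions; the kernel needs the full time axis.
a = a.chunk({"lat": 2, "lon": 3})
b = b.chunk({"lat": 2, "lon": 3})

# Lazy template with the output's variables, dims, coords and chunking.
template = xr.Dataset({"outa": a.isel(time=0), "outb": a.isel(time=0)})

result = xr.map_blocks(do_smthg_xr, a, args=[b], template=template).compute()
```

Building the template from a.isel(time=0) keeps its lat/lon chunks identical to the input, so each input block maps to exactly one output block. With the numba kernel from the question, the same wrapper should apply as long as the blocks are float32.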
Upvotes: 1