Jose

Parallel processing using dask, numba and xarray

I have data stored in an xarray DataArray with dimensions (nt, nb, nx, ny). I wrote some code that calculates some quantities cell-wise (i.e., for each cell in nx and ny) between two different values of the 0-th dimension. My processing can be done independently in nt, but also in nx and ny. I don't know how to get this to run in parallel using dask.
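Cell-wise independence is what makes chunking over nx/ny safe. As a sketch, using a vectorised NumPy stand-in for the real calculation and hypothetical helper names `cellwise` and `tiled`, you can compute each spatial tile separately and stitch the tiles back together. The result matches the whole-array computation:

```python
import numpy as np

def cellwise(ar1, ar2):
    """Vectorised stand-in for the per-cell calculation: reduce over axis 0 only."""
    return (ar1 - ar2).sum(axis=0), (ar1 + ar2).sum(axis=0)

def tiled(ar1, ar2, tile=8):
    """Apply `cellwise` to independent latitude tiles and reassemble the pieces."""
    outa, outb = [], []
    for start in range(0, ar1.shape[1], tile):
        a, b = cellwise(ar1[:, start:start + tile], ar2[:, start:start + tile])
        outa.append(a)
        outb.append(b)
    return np.concatenate(outa, axis=0), np.concatenate(outb, axis=0)

rng = np.random.default_rng(0)
ar1 = rng.random((2, 33, 49), dtype=np.float32)  # synthetic (time, lat, lon) data
ar2 = rng.random((2, 33, 49), dtype=np.float32)

full = cellwise(ar1, ar2)
parts = tiled(ar1, ar2)
print(all(np.allclose(f, p) for f, p in zip(full, parts)))  # True
```

Each tile only needs its own columns of both inputs, so the tiles could in principle run as separate dask tasks.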

The following example demonstrates what I want to do (my calculations are more complex than those used here). However, this runs sequentially, and I'd like to parallelise it, and make use of the chunk structure in the data.

import numpy as np
import xarray as xr
import xarray.tutorial
from numba import njit, float64, float32
from itertools import product

@njit('Tuple((float32[:, :],float32[:,:]))(float32[:, :, :], float32[:, :,:])')
def do_smthg(ar1, ar2):
    n1, n2, n3 = ar1.shape
    outa = np.zeros((n2, n3), dtype=np.float32)
    outb = np.zeros((n2, n3), dtype=np.float32)
    # Loop over the spatial cells (axes 1 and 2); axis 0 is summed inside
    for i in range(n2):
        for j in range(n3):
            outa[i,j] = np.sum(ar1[:, i,j] - ar2[:, i,j])
            outb[i,j] = np.sum(ar1[:, i,j] + ar2[:, i,j])
    return outa, outb
    
da = xarray.tutorial.load_dataset("era5-2mt-2019-03-uk.grib")
da = da.chunk("auto")
F = {}
for (t1,tt1), (t2, tt2) in product(da.t2m.groupby("time.day"),
                           da.t2m.groupby("time.day")):
    # t1 and t2 are day-of-month labels. Calculate for pairs t1 and all days after.
    if t2 > t1:
        F[(t1, t2)] = do_smthg(tt1.values, tt2.values)
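Only pairs with `t2 > t1` are kept, so `itertools.combinations` can generate exactly those pairs. That avoids building the full `product` and discarding more than half of it. This sketch assumes the groups arrive in ascending label order. Plain integers stand in for the groupby labels and strings stand in for the data:

```python
from itertools import combinations, product

# Stand-ins for the (label, group) pairs yielded by groupby("time.day")
groups = [(day, f"data-{day}") for day in range(1, 6)]

# Original pattern: all n*n pairs, then filter
via_product = [(t1, t2) for (t1, _), (t2, _) in product(groups, groups) if t2 > t1]

# combinations yields each unordered pair once, keeping input order (t1 before t2)
via_combinations = [(t1, t2) for (t1, _), (t2, _) in combinations(groups, 2)]

print(via_product == via_combinations)  # True
print(len(via_combinations))  # 10
```

In the loop above, this would mean iterating over `combinations(da.t2m.groupby("time.day"), 2)` and dropping the `if t2 > t1` test.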
    
    

One way to parallelise this would be to have a dask client available, and map things over, but this requires a lot of thinking and data shifting:

from distributed import LocalCluster, Client
cluster = LocalCluster()
client = Client(cluster)
F = {}
for (t1,tt1), (t2, tt2) in product(da.t2m.groupby("time.day"),
                           da.t2m.groupby("time.day")):
    if t2 > t1:
        F[(t1, t2)] = client.submit(do_smthg, tt1.values, tt2.values)
F = {k:v.result() for k,v in F.items()}

This kind of works, but I am unsure there's any clever parallelisation going on. Besides, it needs to shift loads of data around. It looks like the kind of thing that some xarray/dask jiggery pokery should make very efficient. I would want to run this on a large dask cluster, where my datasets would be very large (but chunked).
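One standard xarray tool for the per-cell part is `xr.apply_ufunc` with `dask="parallelized"`. Declare the reduced dimension (time within each group) a core dimension, and the function then runs independently on each spatial chunk. The sketch below makes several assumptions: small synthetic data, a hypothetical NumPy kernel `diff_sum` in place of the numba one, a single pair of days, and dask installed. It parallelises over latitude/longitude chunks, not over day pairs:

```python
import numpy as np
import xarray as xr

def diff_sum(a, b):
    # apply_ufunc moves core dims to the end, so "time" is the last axis here
    return (a - b).sum(axis=-1), (a + b).sum(axis=-1)

rng = np.random.default_rng(0)
shape = (2, 33, 49)  # synthetic (time, latitude, longitude) data
dims = ("time", "latitude", "longitude")
chunks = {"latitude": 11, "longitude": 25}  # "time" stays a single chunk, as required
day1 = xr.DataArray(rng.random(shape, dtype=np.float32), dims=dims).chunk(chunks)
day2 = xr.DataArray(rng.random(shape, dtype=np.float32), dims=dims).chunk(chunks)

outa, outb = xr.apply_ufunc(
    diff_sum,
    day1,
    day2,
    input_core_dims=[["time"], ["time"]],  # reduced inside diff_sum
    output_core_dims=[[], []],             # two outputs, no core dims left
    dask="parallelized",                   # diff_sum runs once per spatial chunk
    output_dtypes=[np.float32, np.float32],
)
# outa and outb are lazy, dask-backed DataArrays with dims (latitude, longitude);
# .compute() or .values evaluates them chunk by chunk
```

With the real data, you would pass `tt1` and `tt2` with their time coordinate dropped. The two days carry different time labels, and apply_ufunc aligns core dimensions by default (with `join="exact"`, so it would raise). A numba kernel could replace `diff_sum` if it expects the core dimension last.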

Using map_blocks isn't clear either:

# template output dataset
out = xr.Dataset(
    data_vars={"outa":(["lat", "lon"], np.random.rand(33, 49)),
               "outb":(["lat", "lon"], np.random.rand(33, 49))})
out.coords["lat"] = da.coords["latitude"].values
out.coords["lon"] = da.coords["longitude"].values
out = out.chunk("auto")

F = {}
for (t1,tt1), (t2, tt2) in product(da.t2m.groupby("time.day"),
                           da.t2m.groupby("time.day")):
    # t1 and t2 are day-of-month labels. Calculate for pairs t1 and all days after.
    if t2 > t1:
        F[(t1, t2)] = tt1.drop("time").map_blocks(do_smthg, args=[tt2.drop("time")], template=out)
F[(1,5)].outb.values

This results in an error when running the numba code

TypeError: No matching definition for argument type(s) pyobject, pyobject

If I remove the numba wrapper and just use the vanilla (slow) Python function, this runs to the end and then fails with this traceback:

~/mambaforge/lib/python3.9/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    117         # temporaries by their reference count and can execute certain
    118         # operations in-place.
--> 119         return func(*(_execute_task(a, cache) for a in args))
    120     elif not ishashable(arg):
    121         return arg

~/mambaforge/lib/python3.9/site-packages/xarray/core/parallel.py in _wrapper(func, args, kwargs, arg_is_array, expected)
    286 
    287         # check all dims are present
--> 288         missing_dimensions = set(expected["shapes"]) - set(result.sizes)
    289         if missing_dimensions:
    290             raise ValueError(

AttributeError: 'numpy.ndarray' object has no attribute 'sizes'

So there's something strange going on here with passing different variables.


Answers (1)

Websy

I believe you need to make the output of do_smthg match the template.

I experienced the same AttributeError: 'numpy.ndarray' object has no attribute 'sizes' using xr.map_blocks when the output of the function was a numpy array. When I converted the output into an xarray object matching the template, the map_blocks code executed as expected.
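`xr.map_blocks` calls the function with xarray objects (one block of each argument) and expects an xarray object back that matches `template`. That is consistent with the `pyobject` error: a DataArray block is handed straight to a numba function compiled for float32 arrays. This sketch makes several assumptions: synthetic data, a hypothetical `block_wrapper`, and a NumPy `kernel` standing in for the numba one. Note that the template reuses the input's `latitude`/`longitude` dimensions and chunks, whereas the question's template uses `lat`/`lon`:

```python
import numpy as np
import xarray as xr

def kernel(ar1, ar2):
    # Stand-in for the numba function: takes and returns plain NumPy arrays
    return (ar1 - ar2).sum(axis=0), (ar1 + ar2).sum(axis=0)

def block_wrapper(block1, block2):
    # map_blocks passes xarray blocks; unwrap them to NumPy for the kernel...
    outa, outb = kernel(block1.values, block2.values)
    # ...then wrap the result with the block's own spatial coordinates
    dims = ("latitude", "longitude")
    coords = {d: block1[d].values for d in dims}
    return xr.Dataset({"outa": (dims, outa), "outb": (dims, outb)}, coords=coords)

rng = np.random.default_rng(0)
shape = (2, 33, 49)  # synthetic (time, latitude, longitude) data
dims = ("time", "latitude", "longitude")
coords = {"latitude": np.linspace(60, 50, 33), "longitude": np.linspace(-10, 2, 49)}
chunks = {"latitude": 11, "longitude": 25}
day1 = xr.DataArray(rng.random(shape, dtype=np.float32), dims=dims, coords=coords).chunk(chunks)
day2 = xr.DataArray(rng.random(shape, dtype=np.float32), dims=dims, coords=coords).chunk(chunks)

# Template: same spatial dims, coords and chunks as the input, with time reduced away
spatial = day1.isel(time=0, drop=True)
template = xr.Dataset({"outa": spatial, "outb": spatial})

result = day1.map_blocks(block_wrapper, args=[day2], template=template)  # lazy
```

For the question's data, `tt1` and `tt2` with their time coordinate dropped would replace `day1`/`day2`. At that point the wrapper hands NumPy arrays to the kernel, so the numba function could be used there, provided the blocks are float32 as its signature requires and its loops cover both spatial axes.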

