Michael Sohnen

Reputation: 980

numba @jit(nopython=True): Refactor function that uses dictionary with lists as values

I have several functions that I want to use numba @jit(nopython=True) for, but they all rely on the following function:

def getIslands(labels2D,ignoreSea=True):
    islands = {}
    width = labels2D.shape[1]
    height = labels2D.shape[0]
    for x in range(width):
        for y in range(height):
            label = labels2D[y,x]
            if ignoreSea and label == -1:
                continue
            if label in islands:
                islands[label].append((x,y))
            else:
                islands[label] = [(x,y)]
    return islands

Is there any way to redesign this function so that it is compatible with numba @jit(nopython=True)? The JIT compilation fails because the function uses features of Python dictionaries that are not supported (i.e., dictionaries containing lists as values).

numba==0.52.0

Upvotes: 1

Answers (1)

Jérôme Richard

Reputation: 50308

Dictionaries and lists are not very user-friendly in Numba yet. You first need to declare the type of the dictionary values (outside the function):

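import numba as nb
import numpy as np  # needed for the np.int_ casts in getIslands

# Dictionary value type: a typed list (ListType) of (x, y) integer pairs.
# Typed dicts reject the builtin nb.types.List as a value type.
intTupleList = nb.types.ListType(nb.types.UniTuple(nb.int_, 2))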

Then you can create an empty typed dictionary inside the function using nb.typed.typeddict.Dict.empty. Lists work the same way, using nb.typed.typedlist.List. Here is how:

@nb.njit('(int_[:,:], bool_)')
def getIslands(labels2D,ignoreSea=True):
    islands = nb.typed.typeddict.Dict.empty(key_type=nb.int_, value_type=intTupleList)
    width = labels2D.shape[1]
    height = labels2D.shape[0]
    for x in range(width):
        for y in range(height):
            label = labels2D[y,x]
            if ignoreSea and label == -1:
                continue
            if label in islands:
                islands[label].append((np.int_(x),np.int_(y)))
            else:
                islands[label] = nb.typed.typedlist.List([(np.int_(x),np.int_(y))])
    return islands

It is a bit sad that Numba cannot yet infer the type of the list [(x, y)], since nb.typed.typedlist.List is somewhat painful to use (especially because of the additional casts required by the mismatch between the nb.int_ type and the loop iterators, which are nb.int64 on 64-bit machines).
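If the typed containers remain too awkward, another option is to avoid dictionaries and lists inside the compiled code entirely. The sketch below is not the method from this answer. It is an alternative that stores all coordinates in a single (N, 2) array grouped by label, plus an offsets array `starts`, so that the coordinates of one label are `coords[starts[k]:starts[k + 1]]`. This is similar to a CSR layout. It assumes every label is an integer >= -1 (with -1 meaning sea), which is typical of connected-component labelling output. The names `get_islands_csr` and `csr_to_dict` are hypothetical. The `try/except` makes it run as plain Python when Numba is not installed. It has not been checked against numba 0.52 specifically.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:  # fall back to plain Python if Numba is not installed
    def njit(func):
        return func


@njit
def get_islands_csr(labels2D, ignore_sea=True):
    """Group pixel coordinates by label using only NumPy arrays.

    Hypothetical helper. Assumes every label is an integer >= -1 (-1 = sea).
    Returns (coords, starts): coords[starts[k]:starts[k + 1]] holds the
    (x, y) pairs of slot k, where slot k is label k (ignore_sea=True) or
    label k - 1 (ignore_sea=False, so that label -1 maps to slot 0).
    """
    height, width = labels2D.shape
    offset = 0 if ignore_sea else 1

    # Pass 1: find the largest label to size the per-label counters.
    max_label = -1
    for y in range(height):
        for x in range(width):
            if labels2D[y, x] > max_label:
                max_label = labels2D[y, x]
    n_slots = max_label + 1 + offset

    # Pass 2: count pixels per slot, shifted by one so cumsum gives offsets.
    counts = np.zeros(n_slots + 1, dtype=np.int64)
    for x in range(width):
        for y in range(height):
            label = labels2D[y, x]
            if ignore_sea and label == -1:
                continue
            counts[label + offset + 1] += 1
    starts = np.cumsum(counts)

    # Pass 3: scatter coordinates, visiting pixels in the same x-then-y
    # order as getIslands so each group keeps that order.
    coords = np.empty((starts[-1], 2), dtype=np.int64)
    fill = starts[:-1].copy()
    for x in range(width):
        for y in range(height):
            label = labels2D[y, x]
            if ignore_sea and label == -1:
                continue
            slot = label + offset
            coords[fill[slot], 0] = x
            coords[fill[slot], 1] = y
            fill[slot] += 1
    return coords, starts


def csr_to_dict(coords, starts, ignore_sea=True):
    """Plain-Python helper: rebuild a {label: [(x, y), ...]} dict."""
    offset = 0 if ignore_sea else 1
    return {
        k - offset: [tuple(p) for p in coords[starts[k]:starts[k + 1]].tolist()]
        for k in range(len(starts) - 1)
        if starts[k + 1] > starts[k]  # skip labels that never occur
    }
```

Under the stated assumption, `csr_to_dict(coords, starts, ignore_sea)` should rebuild the same dictionary that the original pure-Python `getIslands` returns, while the compiled function only ever touches arrays. The extra counting pass is the price for not needing growable per-label lists.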

Upvotes: 1
