Reputation: 2387
This code:
```python
import numpy as np
from numba import jit

@jit(nopython=True)
def foo(x):
    return x.sum(axis=(1, 2))

x = np.linspace(0, 1)
x = x.reshape(5, 5, -1)
print(foo(x))
```
raises this error:
NotImplementedError: No definition for lowering array.sum(array(float64, 3d, C), Tuple(Literal[int](1), Literal[int](2))) -> array(float64, 2d, C)
It seems that Numba supports `np.sum` only when the `axis` argument is a single integer, not a tuple of integers, even though NumPy itself accepts a tuple (https://numpy.org/doc/stable/reference/generated/numpy.ndarray.sum.html#numpy.ndarray.sum).
As a workaround, I use `return x.sum(axis=1).sum(axis=1)`,
but it does not seem like a good solution performance-wise, since the first `sum` allocates an intermediate 2-D array.
Is there another solution, or must I wait for a future Numba version?
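For comparison, here is a minimal, runnable sketch of the workaround from the question next to a hand-written loop. The loop is an alternative not mentioned in the question: it reduces the last two axes without allocating the intermediate array. The function names are hypothetical, the loop assumes floating-point input, and the `try`/`except` fallback only exists so the sketch still runs (uncompiled) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if Numba is not installed
    def njit(func):
        return func

@njit
def sum_last_two_chained(x):
    # Workaround from the question: reduce axis 1 twice.
    # The first call creates an intermediate 2-D array.
    return x.sum(axis=1).sum(axis=1)

@njit
def sum_last_two_loop(x):
    # Alternative: accumulate explicitly, no intermediate array.
    # Assumes a 3-D floating-point array; output is float64.
    out = np.zeros(x.shape[0])
    for i in range(x.shape[0]):
        acc = 0.0
        for j in range(x.shape[1]):
            for k in range(x.shape[2]):
                acc += x[i, j, k]
        out[i] = acc
    return out

x = np.linspace(0, 1).reshape(5, 5, -1)
print(sum_last_two_chained(x))
print(sum_last_two_loop(x))
```

The loop sums in a different order than NumPy, so results can differ in the last few bits; compare them with `np.allclose` rather than `==`.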
Upvotes: 2
Views: 822
Reputation: 249333
Just reshape your array so that it has one dimension more than your desired result, i.e. merge the axes you want to sum into a single axis. In your example:
```python
x.sum(axis=(1, 2))
```
can be replaced by:
```python
x.reshape(5, -1).sum(axis=1)
```
which produces the same result and can be executed by Numba.
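A slightly more general sketch of the same idea, assuming the summed axes are all the axes after the first: take the leading dimension from `x.shape` instead of hard-coding `5`. The helper name `sum_trailing_axes` is hypothetical. As far as I know, Numba's `reshape` only supports contiguous arrays, so a non-contiguous view (such as a transpose) is copied with `np.ascontiguousarray` before the call. The `try`/`except` fallback only lets the sketch run without Numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if Numba is not installed
    def njit(func):
        return func

@njit
def sum_trailing_axes(x):
    # Merge every axis after the first into one, then reduce it.
    # For a 3-D array this matches x.sum(axis=(1, 2)).
    return x.reshape(x.shape[0], -1).sum(axis=1)

x = np.linspace(0, 1).reshape(5, 5, -1)
print(sum_trailing_axes(x))

# A transposed view is not C-contiguous, so copy it to C order
# before passing it to the jitted reshape.
y = np.ascontiguousarray(x.transpose(0, 2, 1))
print(sum_trailing_axes(y))
```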
Upvotes: 3