Stef1611

Numba. How to write np.sum with a tuple axis parameter?

This code:

import numpy as np
from numba import jit

@jit(nopython=True)
def foo(x):
    return x.sum(axis=(1,2))

x = np.linspace(0, 1)
x = x.reshape(5, 5, -1)
print(foo(x))

returns this error:

NotImplementedError: No definition for lowering array.sum(array(float64, 3d, C), Tuple(Literal[int](1), Literal[int](2))) -> array(float64, 2d, C)

It seems that Numba supports np.sum only when the axis parameter is an integer, not a tuple of integers (https://numpy.org/doc/stable/reference/generated/numpy.ndarray.sum.html#numpy.ndarray.sum). So I use this workaround: return x.sum(axis=1).sum(axis=1), but it is not a good solution from a performance point of view.

Is there another solution, or must I wait for a future Numba version?

Answers (1)

John Zwinck

Just reshape your array so that it has one dimension more than your desired result, merging the axes you want to sum into a single axis. In your example:

x.sum(axis=(1,2))

can be replaced by:

x.reshape(5,-1).sum(axis=1)

which produces the same result and can be executed by Numba.
