I am trying to accelerate code with Numba (currently I am using Numba 0.45.1) and have come across a problem with boolean indexing. The code is as follows:
```python
from numba import njit
import numpy as np

n_max = 1000
n_arr = np.hstack((np.arange(1, 3),
                   np.arange(3, n_max, 3)))

@njit
def func(arr):
    idx = np.arange(arr[-1]).reshape((-1, 1)) < arr - 2
    result = np.zeros(idx.shape)
    result[idx] = 10.1
    return result

new_arr = func(n_arr)
```
As soon as I run the code, I get the following message:

```
TypingError: Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 2d, C), array(bool, 2d, C), float64)
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
    raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
In definition 9:
    TypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
    raised from C:\Users\User\Anaconda3\lib\site-packages\numba\typing\arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of setitem at C:/Users/User/Desktop/all python file/5.5.5/numba index broadcasting2.py (29)
```
Note that the (29) at the end of the last line refers to line 29 of my script, which is `result[idx] = 10.1`: the line where I assign a value to the elements of `result` selected by `idx`, a 2-D boolean index.
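For illustration, here is what `idx` looks like for a small, made-up input (plain NumPy, no Numba involved). Comparing an `(M, 1)` column against a length-`N` row broadcasts to an `(M, N)` boolean array, and that 2-D mask is what the failing assignment uses as its index. The input `arr` below is hypothetical, chosen only to keep the output small.

```python
import numpy as np

# Small, hypothetical input in the same spirit as n_arr
arr = np.array([1, 2, 3, 6])

# Column vector 0..arr[-1]-1 (shape (6, 1)) compared against the row
# arr - 2 (shape (4,)): broadcasting yields a (6, 4) boolean array
idx = np.arange(arr[-1]).reshape((-1, 1)) < arr - 2

print(idx.shape)  # (6, 4)
# idx[i, j] is True exactly when i < arr[j] - 2
print(idx)
```

In other words, `result[idx] = 10.1` sets `result[i, j]` wherever `i < arr[j] - 2`.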
I should explain that keeping the statement `result[idx] = 10.1` inside the `@njit` function is a must. As much as I would like to move it out of the `@njit` function, I can't, because this line sits right in the middle of the code I am working on.
If I insist on keeping the assignment `result[idx] = 10.1` inside `@njit`, what exactly needs to change to make it work? If possible, I'd like to see a runnable code example that uses a 2-D boolean index inside `@njit`.
Thank you.
Numba does not currently support advanced ("fancy") indexing with a 2-D array: only one advanced index is allowed, and it has to be a one-dimensional array. See:
https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#array-access
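For contrast, here is a minimal sketch of what that rule does allow: a single 1-D boolean mask on the left-hand side of an assignment inside `@njit`. The function name `fill_where` is hypothetical, and the `try`/`except` fallback exists only so the sketch still runs, uncompiled, where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback only so this sketch runs (without compilation) if Numba is absent
    def njit(func):
        return func


@njit
def fill_where(values, mask, fill):
    # One 1-D boolean array as the only advanced index is within
    # Numba's supported subset, so this setitem type-checks
    out = values.copy()
    out[mask] = fill
    return out


values = np.zeros(5)
mask = np.array([True, False, True, False, False])
print(fill_where(values, mask, 10.1))
```

The same assignment with a 2-D mask is what triggers the `unsupported array index type array(bool, 2d, C)` error in the question.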
However, you can get equivalent behavior by rewriting your function with explicit for-loops instead of relying on broadcasting:
```python
from numba import njit
import numpy as np

n_max = 1000
n_arr = np.hstack((np.arange(1, 3),
                   np.arange(3, n_max, 3)))

# Original logic as plain NumPy (no @njit), used as the reference
def func(arr):
    idx = np.arange(arr[-1]).reshape((-1, 1)) < arr - 2
    result = np.zeros(idx.shape)
    result[idx] = 10.1
    return result

# Loop-based equivalent that Numba can compile
@njit
def func2(arr):
    M = arr[-1]
    N = arr.shape[0]
    result = np.zeros((M, N))
    for i in range(M):
        for j in range(N):
            # Same condition as the broadcast comparison in func
            if i < arr[j] - 2:
                result[i, j] = 10.1
    return result

new_arr = func(n_arr)
new_arr2 = func2(n_arr)
print(np.allclose(new_arr, new_arr2))  # True
```
On my machine, and with the example inputs you provided, `func2` is about 3.5x faster than `func`.
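If the code must keep a masked assignment rather than explicit loops, another option is to flatten both arrays so the mask becomes 1-D, which falls within the supported subset. This is a sketch, not verified against Numba 0.45.1 specifically: it assumes `reshape` returns a view of a C-contiguous array (true for the freshly allocated arrays here), so writing through the flat view updates the 2-D `result`. The name `func3` is hypothetical, and the `try`/`except` fallback only lets the sketch run uncompiled where Numba is absent.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback only so this sketch runs (without compilation) if Numba is absent
    def njit(func):
        return func


@njit
def func3(arr):
    idx = np.arange(arr[-1]).reshape((-1, 1)) < arr - 2
    result = np.zeros(idx.shape)
    # Both arrays are C-contiguous, so reshape gives 1-D views that share
    # memory with the 2-D arrays; the 1-D mask is a supported index
    flat_result = result.reshape(result.size)
    flat_result[idx.reshape(idx.size)] = 10.1
    return result


n_max = 1000
n_arr = np.hstack((np.arange(1, 3),
                   np.arange(3, n_max, 3)))
new_arr3 = func3(n_arr)
print(new_arr3.shape)  # (999, 335)
```

The trick depends on contiguity: for a non-contiguous `result` (for example, a strided slice), `reshape` may not be able to return a view, and the masked write would then not reach the original array.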