Reputation: 1772
I am very new to Numba. I have an n-sized "index" array (configuration) containing integers in [0, m) and a "mapping" array (phase_values) of size m. I want to map each element of the configuration array to the corresponding value from phase_values using numba.guvectorize.
In NumPy, the function looks like:
```python
import numpy as np

def configuration_to_phase_shifts_numpy(configuration, phase_values):
    # np.vectorize applies the lookup element by element (a Python-level loop)
    map_states_to_phase_shifts = np.vectorize(lambda s: phase_values[s])
    return map_states_to_phase_shifts(configuration)
```
Following Numba's documentation, I understand the equivalent vectorized code should be:
```python
import numba

@numba.guvectorize([(numba.int64[:], numba.complex128[:], numba.complex128[:])], '(n),(m)->(n)', nopython=True)
def configuration_to_phase_shifts(configuration, phase_values, phase):
    # phase is the output argument; its shape (n) comes from the layout string
    for i in range(configuration.shape[0]):
        phase[i] = phase_values[configuration[i]]
```
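The layout string '(n),(m)->(n)' is a generalized-ufunc signature written in NumPy's own notation. As an illustration only, here is a pure-NumPy sketch using the signature parameter of np.vectorize (a slow Python-level loop, not a replacement for Numba) that shows what the layout means: n and m are the core dimensions each call sees, and any extra leading dimensions of configuration are looped over. The names _map_row and configuration_to_phase_shifts_np are hypothetical.

```python
import numpy as np

def _map_row(configuration, phase_values):
    # configuration: 1-D int array of shape (n,); phase_values: shape (m,)
    return phase_values[configuration]

# Same core layout as the guvectorize decorator above
configuration_to_phase_shifts_np = np.vectorize(
    _map_row, signature="(n),(m)->(n)", otypes=[np.complex128]
)

# Two configurations stacked along a leading "loop" dimension
configs = np.array([[0, 0, 1], [2, 1, 0]])
phase_values = np.array([1, 2, 3], dtype=complex)
print(configuration_to_phase_shifts_np(configs, phase_values))
```

Each row of configs is mapped independently, so the result has shape (2, 3).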
However, when the function is compiled during its first call:
```python
phase = configuration_to_phase_shifts(configuration, phase_values)
```
Numba throws the following error:
```
Untyped global name 'configuration_to_phase_shifts': cannot determine Numba type of <class 'numpy.ufunc'>
```
So, apart from the obvious question of how to make it work, I don't understand why my function is "untyped" when I have used the guvectorize decorator, declared an output argument, and defined a symbolic output layout.
The above error arises when I call the vectorized function from within a @numba.njit-decorated function, e.g.:
```python
@numba.njit
def foo():
    phase_values = np.array([1,2,3], dtype=complex)
    configuration = np.array([0,0,1,1,2,2])
    phase = configuration_to_phase_shifts(configuration, phase_values)
    do_stuff(phase)  # my own post-processing, not shown
```
```
>>> foo()
Untyped global name 'configuration_to_phase_shifts': cannot determine Numba type of <class 'numpy.ufunc'>
File "<ipython-input>", line 5:
def foo():
<source elided>
configuration = np.array([0,0,1,1,2,2])
phase = configuration_to_phase_shifts(configuration_to_phase_shifts)
^
```
The NumPy indexing approach from the answer, however, solves the problem.
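The error message itself points at the cause: at least in the Numba version used here, guvectorize produces a numpy.ufunc object, and nopython-mode code cannot assign a Numba type to it, so foo fails to compile even though the gufunc works when called from plain Python. Newer Numba releases may behave differently. A minimal sketch of one workaround, assuming the mapping only needs to be called from jitted code, is to write the kernel as an ordinary @numba.njit function that allocates its own output. The name configuration_to_phase_shifts_jit is hypothetical, and the ImportError fallback only exists so the sketch also runs (slowly) without Numba.

```python
import numpy as np

try:
    import numba
    njit = numba.njit
except ImportError:  # fallback: plain Python if Numba is not installed
    def njit(func):
        return func

@njit
def configuration_to_phase_shifts_jit(configuration, phase_values):
    # Ordinary jitted loop instead of a gufunc, so other njit code can call it
    out = np.empty(configuration.shape[0], dtype=np.complex128)
    for i in range(configuration.shape[0]):
        out[i] = phase_values[configuration[i]]
    return out

@njit
def foo():
    phase_values = np.array([1, 2, 3], dtype=np.complex128)
    configuration = np.array([0, 0, 1, 1, 2, 2])
    return configuration_to_phase_shifts_jit(configuration, phase_values)

print(foo())
```

Inside njit code, the one-line phase_values[configuration] works as well, as noted above.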
Upvotes: 0
Views: 162
Reputation: 3437
Your code seems to work fine without any changes using Numba 0.53 and Python 3.8 on this example:
```
>>> configuration = np.array([1,2,0], dtype=int)
>>> phase_values = np.array([4,6,8,1], dtype=complex)
>>> configuration_to_phase_shifts(configuration, phase_values)
array([6.+0.j, 8.+0.j, 4.+0.j])
```
However, the same result can be obtained with plain NumPy indexing, which is already vectorized and needs no further optimization:
```
>>> phase_values[configuration]
array([6.+0.j, 8.+0.j, 4.+0.j])
```
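A small NumPy-only sketch of why indexing is the simplest route: fancy indexing returns an array with the shape of the index array, so it also maps a batch of configurations in one call, and it agrees with np.take and with the np.vectorize version from the question. The configurations array below is made-up data for illustration.

```python
import numpy as np

phase_values = np.array([4, 6, 8, 1], dtype=complex)

# Same 1-D example as above: indexing and np.take agree
configuration = np.array([1, 2, 0])
assert np.array_equal(phase_values[configuration], np.take(phase_values, configuration))

# Hypothetical 2-D index array (one configuration per row):
# the result has the same shape as the index array
configurations = np.array([[1, 2, 0], [3, 3, 0]])
mapped = phase_values[configurations]

# Same values as the element-wise np.vectorize approach from the question
via_vectorize = np.vectorize(lambda s: phase_values[s])(configurations)
assert np.array_equal(mapped, via_vectorize)
print(mapped)
```

This is plain NumPy; as far as I know, Numba's nopython mode only supports a single one-dimensional integer array as an advanced index, so inside @njit code the 1-D form is the safe one.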
This should be a comment instead of an answer, but formatting is harder in comments.
Upvotes: 1