Reputation: 2011
Inside a numba-jitted nopython function, I need to index an array with the values inside another array. Both arrays are NumPy arrays of floats.
For example:
```python
import numba

@numba.jit("void(f8[:], f8[:], f8[:])", nopython=True)
def need_a_cast(sources, indices, destinations):
    for i in range(indices.size):
        destinations[i] = sources[indices[i]]
```
My code is different, but let's assume the problem is reproducible with this stupid example (i.e., I cannot have indices of type int). AFAIK, I cannot use int(indices[i]) or indices[i].astype("int") inside a nopython jit function.
How do I do this?
Upvotes: 6
Views: 5314
Reputation: 152775
If you really cannot use int(indices[i]) (it works for JoshAdel and also for me), you should be able to work around it with math.trunc or math.floor:
```python
import math
...
        destinations[i] = sources[math.trunc(indices[i])]  # truncate (py2 and py3)
        destinations[i] = sources[math.floor(indices[i])]  # round down (only py3)
```
As far as I know, math.floor works only on Python 3, because on Python 2 it returns a float. math.trunc, on the other hand, rounds toward zero, so for negative values it rounds up rather than down.
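Since int(), math.trunc and math.floor only disagree on negative inputs, a quick plain-Python check (Python 3, no numba needed) makes the difference concrete. For non-negative indices, the usual case for array positions, all three give the same integer.

```python
import math

# Compare the three float-to-int conversions discussed above
values = [2.7, -2.7, -0.5]
as_int = [int(v) for v in values]           # truncates toward zero
as_trunc = [math.trunc(v) for v in values]  # truncates toward zero
as_floor = [math.floor(v) for v in values]  # rounds toward negative infinity (returns int on Python 3)

print(as_int)    # [2, -2, 0]
print(as_trunc)  # [2, -2, 0]
print(as_floor)  # [2, -3, -1]
```

Note that a negative result such as -1 would not raise an error in NumPy-style indexing; it is interpreted as counting from the end of the array.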
Upvotes: 3
Reputation: 68722
Using numba 0.24 at least, you can do a simple cast:
```python
import numpy as np
import numba as nb

@nb.jit(nopython=True)
def need_a_cast(sources, indices, destinations):
    for i in range(indices.size):
        # Cast the float index to an integer before indexing
        destinations[i] = sources[int(indices[i])]

sources = np.arange(10, dtype=np.float64)
indices = np.arange(10, dtype=np.float64)
np.random.shuffle(indices)
destinations = np.empty_like(sources)
print(indices)
need_a_cast(sources, indices, destinations)
print(destinations)
# Result (the shuffle is random, so the order differs between runs;
# because sources is arange(10), destinations equals indices)
# [ 3.  2.  8.  1.  5.  6.  9.  4.  0.  7.]
# [ 3.  2.  8.  1.  5.  6.  9.  4.  0.  7.]
```
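Because sources and indices coincide in the example above, the output cannot show that the gather actually happened. The sketch below uses distinct source values and checks the int() cast loop against NumPy's own fancy indexing with an index array converted once via astype(np.intp). The name gather_by_float_index is hypothetical, and the fallback decorator is an assumption of this sketch so it still runs as plain Python when numba is not installed.

```python
import numpy as np

try:
    from numba import njit  # numba's nopython-mode decorator
except ImportError:
    # Assumption: without numba, run the same loop as plain Python
    def njit(func):
        return func


@njit
def gather_by_float_index(sources, indices, destinations):
    # Cast each float index to an integer before using it to index `sources`
    for i in range(indices.size):
        destinations[i] = sources[int(indices[i])]


sources = np.arange(10, dtype=np.float64) * 10.0
indices = np.array([3.0, 2.0, 8.0, 1.0], dtype=np.float64)
destinations = np.empty(indices.size, dtype=np.float64)
gather_by_float_index(sources, indices, destinations)

# Reference result: NumPy fancy indexing with an explicitly converted index array
expected = sources[indices.astype(np.intp)]
print(destinations)  # [30. 20. 80. 10.]
```

Converting the whole index array up front with astype is only possible outside the constraint described in the question; inside the jitted loop, the per-element int() cast does the same job.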
Upvotes: 3