How to speed up a function with numba when input and return are dictionaries?
I'm familiar with using numba for functions that accept numbers and return arrays, like this:
import numba
import numpy as np
@numba.jit('float64[:](int32,int32)', nopython=True)
def f(a, b):
    # returns a 1D float64 array (placeholder body for illustration)
    return np.zeros(a + b)
Now I have a function that accepts and returns dictionaries. How can I apply numba here?
def collocation(aeolus_data, val_data):
    ...
    return sample_aeolus, sample_valdata
Support for dictionaries (numba.typed.Dict) was added in Numba 0.43.0. It is quite limited, though: at the time of writing, lists and sets are not supported as keys or values. See the updated documentation listed in the references below for details.
Here is an example:
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict
# First create a dictionary using Dict.empty()
# Specify the data types for both key and value pairs
# Dict with keys as strings and values of type float array
dict_param1 = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64[:],
)
# Dict with keys as strings and values of type float
dict_param2 = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,
)
# Type expressions are currently not supported inside jit functions,
# so build the array type outside and reference it as a global.
float_array = types.float64[:]

@njit
def add_values(d_param1, d_param2):
    # Make a result dictionary to store results
    # Dict with keys as strings and values of type float array
    result_dict = Dict.empty(
        key_type=types.unicode_type,
        value_type=float_array,
    )
    for key in d_param1.keys():
        result_dict[key] = d_param1[key] + d_param2[key]
    return result_dict

dict_param1["hello"] = np.asarray([1.5, 2.5, 3.5], dtype='f8')
dict_param1["world"] = np.asarray([10.5, 20.5, 30.5], dtype='f8')
dict_param2["hello"] = 1.5
dict_param2["world"] = 10
final_dict = add_values(dict_param1, dict_param2)
print(final_dict)
# Output : {hello: [3. 4. 5.], world: [20.5 30.5 40.5]}
References:
- https://github.com/numba/numba/issues/3644
- https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#dict