Nin17

Reputation: 3492

Numba promotes types differently to numpy

When adding (or multiplying, dividing, subtracting, etc.) a Python float to a NumPy array, NumPy preserves the dtype of the array, whereas Numba promotes the result to float64. How can I modify the overload of ndarray.__add__ (and the other operators) so that the Python float is cast to the array's dtype and the result keeps that dtype?

Ideally, I don't want to modify my functions. I would rather implement a new overload of the underlying addition (and similar) functions, as there are many instances of this in my code.

Code to demonstrate the issue. I would like a function decorated with njit to behave consistently with NumPy:

import numpy as np
import numba as nb

def func(array):
    return array + 1.0

numba_func = nb.njit(func)

a_f64 = np.ones(1, dtype=np.float64)
a_f32 = np.ones(1, dtype=np.float32)

for i in (a_f64, a_f32):
    print(i.dtype)
    print(func(i).dtype)
    print(numba_func(i).dtype, end="\n\n")

Output (with numpy 2.1.3 and numba 0.61.0):

float64
float64
float64

float32
float32
float64

Upvotes: 1

Views: 90

Answers (1)

roganjosh

Reputation: 13185

It comes from this passage in the Numba documentation:

Numpy will most often return a float64 as a result of a computation with mixed integer and floating-point operands (a typical example is the power operator **). Numba by contrast will select the highest precision amongst the floating-point operands, so for example float32 ** int32 will return a float32, regardless of the input values. This makes performance characteristics easier to predict, but you should explicitly cast the input to float64 if you need the extra precision.

You can fix your example simply by using:

import numpy as np
import numba as nb

def func(array):
    # Cast the scalar to float32 explicitly so Numba does not type it as float64
    return array + np.float32(1.0)

numba_func = nb.njit(func)

a_f64 = np.ones(1, dtype=np.float64)
a_f32 = np.ones(1, dtype=np.float32)
for i in (a_f64, a_f32):
    print(i.dtype)
    print(func(i).dtype)
    print(numba_func(i).dtype, end="\n\n")

The problem is that Numba types the Python float literal 1.0 as float64, so adding it to a float32 array upcasts the whole result to float64.

I don't think this makes a whole lot of sense, but it fixes the upcasting to 64-bit.

This is my output:

float64
float64
float64

float32
float32
float32

This is with numpy '1.26.4' and numba '0.60.0' in Python 3.12.0. I don't think I've solved the whole problem here.

Upvotes: 1
