Is it possible to make CPU only reductions with JAX comparable to Numba in terms of computation time?
The compilers come straight from conda:
$ conda install -c conda-forge numba jax
Here is a 1-d NumPy array example:
import numpy as np
import numba as nb
import jax as jx

@nb.njit
def reduce_1d_njit_serial(x):
    # simple serial sum, compiled by Numba
    s = 0
    for xi in x:
        s += xi
    return s

@jx.jit
def reduce_1d_jax_serial(x):
    # same loop, traced and compiled by JAX/XLA
    # (the Python loop is unrolled at trace time)
    s = 0
    for xi in x:
        s += xi
    return s

N = 2**10
a = np.random.randn(N)
Using timeit on the following:
np.add.reduce(a) gives 1.99 µs ...
reduce_1d_njit_serial(a) gives 1.43 µs ...
reduce_1d_jax_serial(a).item() gives 23.5 µs ...
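(For reference, a timing like the first can be reproduced with the standard-library timeit module; a minimal sketch, with the numbers varying by machine:)

import timeit

n_runs = 100_000
t = timeit.timeit(lambda: np.add.reduce(a), number=n_runs)
print(f"{1e6 * t / n_runs:.2f} µs per call")  # machine-dependent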
Note that jx.numpy.sum(a) and a jx.lax.fori_loop-based reduction give comparable (marginally slower) times to reduce_1d_jax_serial.
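For concreteness, the fori_loop variant mentioned above might be written like this (a sketch; reduce_1d_fori is my name for it, not from the original post):

import jax as jx
import jax.numpy as jnp

@jx.jit
def reduce_1d_fori(x):
    # fori_loop threads the running sum through as loop-carried state;
    # start from a zero of x's dtype so the carry type stays consistent
    return jx.lax.fori_loop(0, x.shape[0], lambda i, s: s + x[i],
                            jnp.zeros((), x.dtype))

This lowers to a sequential XLA while-loop, so it typically won't beat jnp.sum, which maps onto XLA's native reduction.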
It seems there should be a better way to craft the reduction for XLA.
EDIT: compile times were not included, as a print statement that checked the results preceded the timing (so each function was already compiled).
Answer:
When performing these kinds of microbenchmarks with JAX, you have to be careful to ensure you're measuring what you think you're measuring. There are some tips in the JAX Benchmarking FAQ. Implementing some of these best practices, I find the following for your benchmarks:
import jax.numpy as jnp
# Native jit-compiled XLA sum
jit_sum = jx.jit(jnp.sum)
# Avoid including device transfer cost in the benchmarks
a_jax = jnp.array(a)
# Prevent measuring compilation time
_ = reduce_1d_njit_serial(a)
_ = reduce_1d_jax_serial(a_jax)
_ = jit_sum(a_jax)
%timeit np.add.reduce(a)
# 100000 loops, best of 5: 2.33 µs per loop
%timeit reduce_1d_njit_serial(a)
# 1000000 loops, best of 5: 1.43 µs per loop
%timeit reduce_1d_jax_serial(a_jax).block_until_ready()
# 100000 loops, best of 5: 6.24 µs per loop
%timeit jit_sum(a_jax).block_until_ready()
# 100000 loops, best of 5: 4.37 µs per loop
You'll see that for these microbenchmarks, JAX is a few microseconds slower than both numpy and numba. So does this mean JAX is slow? Yes and no; you'll find a more complete answer to that question in JAX FAQ: is JAX faster than numpy?. The short summary is that this computation is so small that the timings are dominated by Python dispatch overhead rather than time spent operating on the array. The JAX project has not put much effort into optimizing Python dispatch for microbenchmarks: it's not all that important in practice because the cost is incurred once per program in JAX, as opposed to once per operation in numpy.
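One way to convince yourself of this (a sketch; 2**20 is an arbitrary choice of a "large enough" size, and the timings are machine-dependent): repeat the benchmark with a much larger array, where the reduction itself dominates dispatch overhead and the gap should largely close.

# A larger array makes the reduction itself, not dispatch, dominate
a_big = np.random.randn(2**20)
a_big_jax = jnp.array(a_big)  # move to device once, outside the timing
_ = jit_sum(a_big_jax)        # compile once before timing
%timeit np.add.reduce(a_big)
%timeit jit_sum(a_big_jax).block_until_ready()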