Reputation: 4864
I have been playing with jax lately, and it is very impressive, but the following set of experiments confused me greatly:
First, we set up the timer utility:
import time

def timefunc(foo, *args):
    tic = time.perf_counter()
    tmp = foo(*args)
    toc = time.perf_counter()
    print(toc - tic)
    return tmp
Now, let’s see what happens when we compute the eigenvalues of a random symmetric matrix (jnp is jax.numpy, so the eigh is done on the GPU):
import numpy as np
import jax.numpy as jnp

def jfunc(n):
    tmp = np.random.randn(n, n)
    return jnp.linalg.eigh(tmp + tmp.T)

def nfunc(n):
    tmp = np.random.randn(n, n)
    return np.linalg.eigh(tmp + tmp.T)
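As a quick sanity check that JAX is in fact targeting the GPU here (this assumes a CUDA-enabled jaxlib install; it is not part of the timing experiments), you can ask JAX which backend and devices it will dispatch to:

import jax

print(jax.default_backend())   # "gpu" when the CUDA backend is active
print(jax.devices())           # lists the visible accelerators, e.g. the A100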
Now for the timings (the machine is an nVidia DGX box, so the GPU is an A100, while the CPUs are some AMD EPYC2 parts):
>>> e1 = timefunc(nfunc, 10)
0.0002442029945086688
>>> e2 = timefunc(jfunc, 10)
0.013523647998226807
>>> e1 = timefunc(nfunc, 100)
0.11742364699603058
>>> e2 = timefunc(jfunc, 100)
0.11005625998950563
>>> e1 = timefunc(nfunc, 1000)
0.6572738009999739
>>> e2 = timefunc(jfunc, 1000)
0.5530761769914534
>>> e1 = timefunc(nfunc, 10000)
36.22587636699609
>>> e2 = timefunc(jfunc, 10000)
8.867857075005304
You will notice that the crossover is somewhere around 1000. Initially, I thought this was because of the overhead of moving stuff to/from the GPU, but if you define yet another function:
import jax

def jjfunc(n):
    key = jax.random.PRNGKey(0)
    tmp = jax.random.normal(key, [n, n])
    return jnp.linalg.eigh(tmp + tmp.T)
>>> e1=timefunc(jjfunc, 10)
0.01886096798989456
>>> e1=timefunc(jjfunc, 100)
0.2756766739912564
>>> e1=timefunc(jjfunc, 1000)
0.7205733209993923
>>> e1=timefunc(jjfunc, 10000)
6.8624101399909705
Note that the small examples are actually (much) slower than moving the numpy array to the GPU and back.
So, my question is: what is going on, and is there a silver bullet? Is this a jax implementation bug?
Upvotes: 1
Views: 835
Reputation: 86320
I don't think your timings are reflective of actual JAX vs. numpy performance, for a few reasons:
1. JAX dispatches computations asynchronously, so you should call the result's block_until_ready() method to ensure you are timing the computation rather than the dispatch.
2. Because functions like eigh are JIT-compiled by default, the first time you run them for a given size will incur the one-time compilation cost. Subsequent runs will be faster as JAX caches previous compilations.
3. Your jfunc timings include generating the random matrix with numpy and transferring it to the device; for n = 10000 that alone accounts for several seconds:

    def transfer(n):
        tmp = np.random.randn(n, n)
        return jnp.array(tmp).block_until_ready()

    timefunc(transfer, 10000);
    # 4.600406924000026

4. Your jjfunc combines the eigh call with the jax.random.normal call. The latter is slower than numpy's random number generation, and I believe is dominating the difference for small n.
5. Using time.time for profiling Python code can give you misleading results. Modules like timeit are much better for this kind of thing, particularly when you're dealing with microbenchmarks that complete in fractions of a second.

If you're interested in accurate benchmarks of JAX vs. Numpy versions of algorithms, I'd suggest isolating exactly the operations you're interested in benchmarking (i.e. generate the data & do any device transfer outside the benchmarks); a rough sketch of this is shown below. Read up on the advice in Asynchronous Dispatch in JAX as it relates to benchmarking, and check out Python's timeit docs for tips on getting accurate timings of small code snippets (though I find the %timeit magic more convenient when working in IPython or a Jupyter notebook).
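For concreteness, here is a rough sketch of what such an isolated benchmark might look like (illustrative only; the matrix size and repetition count are arbitrary choices, not values tuned for your hardware):

import timeit
import numpy as np
import jax.numpy as jnp

n = 1000

# Generate the data and move it to the device *outside* the timed region.
tmp = np.random.randn(n, n)
x_np = tmp + tmp.T              # symmetric matrix for the numpy benchmark
x_jax = jnp.array(x_np)         # one-time host-to-device transfer

# Warm-up call so the one-time JIT compilation cost is excluded below.
jnp.linalg.eigh(x_jax)[0].block_until_ready()

# Time only the eigendecomposition; block_until_ready() forces the
# asynchronously dispatched JAX computation to finish before the timer stops.
jax_time = timeit.timeit(
    lambda: jnp.linalg.eigh(x_jax)[0].block_until_ready(), number=10)
numpy_time = timeit.timeit(
    lambda: np.linalg.eigh(x_np), number=10)

print("jax  :", jax_time / 10, "s per call")
print("numpy:", numpy_time / 10, "s per call")

Structured this way, the random number generation, the device transfer, and the compilation all happen before the timers start, so the two numbers compare only the eigh computations themselves.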
Upvotes: 3