Igor Rivin

Reputation: 4864

GPU and `jax` performance mysteries

I have been playing with jax lately, and it is very impressive, but then the following set of experiments confused me greatly:

First, we set up the timer utility:

import time

def timefunc(foo, *args):
  tic = time.perf_counter()
  tmp = foo(*args)           # run the function once
  toc = time.perf_counter()
  print(toc - tic)           # elapsed wall-clock time in seconds
  return tmp

Now, let’s see what happens when we compute the eigenvalues of a random symmetric matrix (jnp is jax.numpy, so the eigh is done on the GPU):

import numpy as np
import jax.numpy as jnp

def jfunc(n):
  tmp = np.random.randn(n, n)
  return jnp.linalg.eigh(tmp + tmp.T)   # numpy input, eigh runs on the GPU

def nfunc(n):
  tmp = np.random.randn(n, n)
  return np.linalg.eigh(tmp + tmp.T)    # everything stays on the CPU

Now for the timings (the machine is an NVIDIA DGX box, so the GPU is an A100, while the CPUs are AMD EPYC2 parts):

>>> e1 = timefunc(nfunc, 10)
0.0002442029945086688
>>> e2 = timefunc(jfunc, 10)
0.013523647998226807
>>> e1 = timefunc(nfunc, 100)
0.11742364699603058
>>> e2 = timefunc(jfunc, 100)
0.11005625998950563
>>> e1 = timefunc(nfunc, 1000)
0.6572738009999739
>>> e2 = timefunc(jfunc, 1000)
0.5530761769914534
>>> e1 = timefunc(nfunc, 10000)
36.22587636699609
>>> e2 = timefunc(jfunc, 10000)
8.867857075005304

You will notice that the crossover is somewhere around 1000. Initially, I thought this was because of the overhead of moving data to and from the GPU, but if you define yet another function:

import jax

def jjfunc(n):
  key = jax.random.PRNGKey(0)
  tmp = jax.random.normal(key, [n, n])   # random matrix generated directly on the GPU
  return jnp.linalg.eigh(tmp + tmp.T)


>>> e1=timefunc(jjfunc, 10)
0.01886096798989456
>>> e1=timefunc(jjfunc, 100)
0.2756766739912564
>>> e1=timefunc(jjfunc, 1000)
0.7205733209993923
>>> e1=timefunc(jjfunc, 10000)
6.8624101399909705

Note that for small sizes this is actually (much) slower than the version that moves the numpy array to the GPU and back.

So, my question is: what is going on, and is there a silver bullet? Is this a jax implementation bug?

Upvotes: 1

Views: 835

Answers (1)

jakevdp

Reputation: 86320

I don't think your timings are reflective of actual JAX vs. numpy performance, for a few reasons:

  • JAX's computation model uses Asynchronous Dispatch, which means that JAX operations return before the computation is finished. As mentioned at that link, you can use the block_until_ready() method to ensure you are timing the computation rather than the dispatch.
  • Because operations like eigh are JIT-compiled by default, the first time you run them for a given size incurs a one-time compilation cost. Subsequent runs are faster because JAX caches previous compilations (the first sketch below this list illustrates both of these effects).
  • Your computations are indeed being foiled by device transfer costs. This is easiest to see if you measure the transfer directly:
    def transfer(n):
      tmp = np.random.randn(n, n)
      return jnp.array(tmp).block_until_ready()
    timefunc(transfer, 10000);
    # 4.600406924000026
    
  • Your jjfunc combines the eigh call with the jax.random.normal call. The latter is slower than numpy's random number generation, and I believe it dominates the difference for small n (the second sketch below this list isolates the two generators).
  • Unrelated to JAX, but in general, one-shot manual timing (time.time or time.perf_counter) for profiling Python code can give you misleading results. Modules like timeit are much better for this kind of thing, particularly when you're dealing with microbenchmarks that complete in fractions of a second.
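
As a rough sketch of the first two points (the function name and the size below are just for illustration), something along these lines shows both the one-time compilation cost and why block_until_ready() matters:

import time
import jax
import jax.numpy as jnp

def eigh_timings(n):
  # Build a symmetric matrix directly on the device so no transfer is involved.
  key = jax.random.PRNGKey(0)
  tmp = jax.random.normal(key, (n, n))
  mat = (tmp + tmp.T).block_until_ready()

  for label in ("first call (includes compilation)", "second call (cached)"):
    tic = time.perf_counter()
    w, _ = jnp.linalg.eigh(mat)
    w.block_until_ready()        # wait for the result, not just the dispatch
    toc = time.perf_counter()
    print(label, toc - tic)

eigh_timings(1000)

Without the block_until_ready() call inside the loop you would mostly be timing the dispatch, and without the second iteration you would fold the compilation cost into the measurement.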

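To see the fourth point in isolation, a comparison along these lines (reusing the timefunc helper from the question; the size is deliberately small) times the two random-number generators by themselves:

import numpy as np
import jax

def np_rng(n):
  return np.random.randn(n, n)            # numpy RNG on the CPU

def jax_rng(n):
  key = jax.random.PRNGKey(0)
  # block_until_ready so we time the generation, not just the dispatch;
  # note that the first call also includes compilation overhead.
  return jax.random.normal(key, (n, n)).block_until_ready()

timefunc(np_rng, 10)
timefunc(jax_rng, 10)
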
If you're interested in accurate benchmarks of JAX vs. Numpy versions of algorithms, I'd suggest isolating exactly the operations you're interested in benchmarking (i.e. generate the data & do any device transfer outside the benchmarks). Read up on the advice in Asynchronous Dispatch in JAX as it relates to benchmarking, and check out Python's timeit Docs for tips on getting accurate timings of small code snippets (though I find the %timeit magic more convenient when working in IPython or a Jupyter notebook).
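
For example, a minimal benchmark along those lines might look like the following (the size and repeat count are arbitrary): the data is generated and transferred outside the timed region, JAX is warmed up first so compilation is excluded, and timeit handles the repetition:

import timeit
import numpy as np
import jax.numpy as jnp

n = 1000
x = np.random.randn(n, n)
x = x + x.T                                # symmetric input, built outside the timed region

x_jax = jnp.array(x).block_until_ready()   # one-time transfer to the device

# Warm up so compilation is not part of the measurement.
jnp.linalg.eigh(x_jax)[0].block_until_ready()

t_np = timeit.timeit(lambda: np.linalg.eigh(x), number=10)
t_jax = timeit.timeit(lambda: jnp.linalg.eigh(x_jax)[0].block_until_ready(), number=10)
print(f"numpy: {t_np / 10:.4f} s/call   jax: {t_jax / 10:.4f} s/call")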

Upvotes: 3
