akkh

Reputation: 139

How to reduce JAX compile time when using for loop?

This is a basic example.

@jax.jit
def block(arg1, arg2):
    for x1 in range(cons1):
        for x2 in range(cons2):
            for x3 in range(cons3):
                --do something--
    return result

When the cons are small, compile time is around a minute. With larger cons it grows to tens of minutes, and I need the cons to be higher still. What can be done? From what I have read, the loops are the cause: they are unrolled during tracing, before compilation. Are there any workarounds? There is also jax.lax.fori_loop, but I don't understand how to use it. There is the jax.experimental.loops module as well, but again I'm not able to understand it.

I am very new to all this, so any help is appreciated. If you can provide some examples of how to use JAX's loop constructs, that would be especially helpful.

Also, what is an acceptable compile time? Is it normal for it to be in the minutes? In one of my examples, compilation takes 262 seconds while each subsequent run takes ~0.1-0.2 seconds.

Any gain in runtime is overshadowed by the compile time.

Upvotes: 4

Views: 7928

Answers (2)

jakevdp

Reputation: 86443

JAX's JIT compiler flattens all Python loops. To see what I mean, take a look at this simple function run through jax.make_jaxpr, which is a way to examine how JAX's tracer interprets Python code (see Understanding Jaxprs for more):

import jax

def f(x):
  for i in range(5):
    x += i
  return x

print(jax.make_jaxpr(f)(0))
# { lambda  ; a.
#   let b = add a 0
#       c = add b 1
#       d = add c 2
#       e = add d 3
#       f = add e 4
#   in (f,) }

Notice that the loop is flattened: every step becomes an explicit operation sent to the XLA compiler. The XLA compile time increases as you increase the number of operations in the function, so it makes sense that a triply-nested for-loop would lead to long compile times.
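
For contrast, here is a sketch (not from the original answer, and the printed form varies across JAX versions) of the same computation written with jax.lax.fori_loop. The loop appears in the jaxpr as a single structured primitive, so the trace no longer grows with the trip count:

import jax
import jax.numpy as jnp
from jax import lax

def g(x):
  # body_fun(i, carry) is applied for i = 0..4, threading the carry through.
  return lax.fori_loop(0, 5, lambda i, x: x + i, x)

print(jax.make_jaxpr(g)(jnp.int32(0)))
# The jaxpr contains one loop primitive rather than five unrolled adds.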

So, how to address this? Unfortunately the answer depends on what your --do something-- is doing, so I can't guess at it here.

In general, the best option is to use vectorized array operations rather than loops over the values in those vectors; for example, here is a very slow way of adding two vectors:

import jax.numpy as jnp

def f_slow(x, y):
  # Each iteration traces to its own XLA operation, so the trace
  # (and compile time) grows with the length of the inputs.
  z = []
  for xi, yi in zip(x, y):
    z.append(xi + yi)
  return jnp.array(z)

and here is a much faster way to do the same thing:

def f_fast(x, y):
  return x + y

If your operations don't lend themselves to vectorization, another option is to use lax control flow operators in place of the for loops: this pushes the loop down into XLA instead of unrolling it. Loops written this way can perform quite well on CPU, but tend to be slower on accelerators than equivalent vectorized array operations.
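
For example, here is a hedged sketch of the question's triply-nested loop rewritten with jax.lax.fori_loop. Since the real --do something-- isn't specified, the body below is a placeholder, and cons1/cons2/cons3 are hypothetical bounds:

import jax
import jax.numpy as jnp
from jax import lax

cons1, cons2, cons3 = 10, 10, 10  # hypothetical loop bounds

@jax.jit
def block(arg1, arg2):
  # lax.fori_loop(lower, upper, body_fun, init) calls body_fun(i, carry),
  # threading the carry through; XLA sees one loop rather than an unrolled
  # trace, so compile time stays flat as the bounds grow.
  def body3(x3, carry):
    return carry + arg1 * x3  # placeholder for --do something--
  def body2(x2, carry):
    return lax.fori_loop(0, cons3, body3, carry)
  def body1(x1, carry):
    return lax.fori_loop(0, cons2, body2, carry)
  return lax.fori_loop(0, cons1, body1, arg2)

print(block(jnp.float32(2.0), jnp.float32(0.0)))

If you need per-iteration outputs rather than a single carried value, jax.lax.scan is usually the better fit.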

For more discussion on JAX and Python control flow statements (such as for, if, while, etc.), see 🔪 JAX - The Sharp Bits 🔪: Control Flow.

Upvotes: 5

dankal444

Reputation: 4158

I am not sure if this will be the same as with numba, but it might be a similar case.

When I use the numba.jit compiler on big data input, I first compile the function on some small example data, then use it on the full input.

Pseudo-code:

func_being_compiled(small_amount_of_data)  # compile-only warm-up call
func_being_compiled(large_amount_of_data)  # reuses the compiled function
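
One caveat when carrying this pattern over to JAX (a sketch with a hypothetical workload): JAX caches compilations per input shape and dtype, so the warm-up call only helps if it uses the same shapes as the real call:

import jax
import jax.numpy as jnp

@jax.jit
def func_being_compiled(x):
  return (x ** 2).sum()  # hypothetical workload

data = jnp.ones((1000, 1000))
func_being_compiled(data).block_until_ready()  # first call traces and compiles
func_being_compiled(data).block_until_ready()  # later calls reuse the cached executable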

Upvotes: 0
