neel g

Reputation: 1255

Jax: generating random numbers under **JIT**

I have a setup where I need to generate some random number that is consumed by vmap and then lax.scan later on:

def generate_random(key: Array, upper_bound: int, lower_bound: int) -> int:
    ...
    return num.astype(int)

def forward(key: Array, input: Array) -> Array:
    k = generate_random(key, 1, 5)
    computation = model(.., k, ..)
    ...

# Computing the forward pass
output = jax.vmap(forward, in_axes=.....

But attempting to convert num from a jax.Array to a concrete Python int causes a ConcretizationTypeError.

This can be reproduced through this minimal example:

@jax.jit
def t():
  return jnp.zeros((1,)).item().astype(int)  # .item() on a traced array raises ConcretizationTypeError

o = t()
o

JIT requires that all manipulations stay within JAX array types.

But vmap uses JIT implicitly, and I would prefer to keep it for performance reasons.


My Attempt

This was my hacky attempt:

@partial(jax.jit, static_argnums=(1, 2))
def get_rand_num(key: Array, lower_bound: int, upper_bound: int) -> int:
  key, subkey = jax.random.split(key)
  random_number = jax.random.randint(subkey, shape=(), minval=lower_bound, maxval=upper_bound)
  return random_number.astype(int)

def react_forward(key: Array, input: Array) -> Array:
  k = get_rand_num(key, 1, MAX_ITERS)
  # forward pass the model without tracking grads
  intermediate_array = jax.lax.stop_gradient(model(input, k)) # THIS LINE ERRORS OUT
  ...
  return ...

a = jnp.zeros((300, 32)).astype(int)
rndm_keys = jax.random.split(key, a.shape[0])
jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape

This involves creating batch_size (a.shape[0]) subkeys, one per example in the vmapped batch, so each call gets its own random number.

But it doesn't work, because k is being cast from jax.Array -> int.

But making these changes:

-  k = get_rand_num(key, 1, MAX_ITERS)
+  k = 5 # any hardcoded int

Works perfectly. Clearly, the sampling is causing the problem here...


Clarifications

To avoid turning this into an X-Y problem, I'll define precisely what I want:

I'm implementing a version of stochastic depth: my model's forward pass accepts a depth: int at runtime, which is the length of a scan run internally - specifically, xs = jnp.arange(depth) for the scan.

I want my architecture to flexibly adapt to different depths. Therefore, at training time, I need a way to produce pseudorandom numbers that determine the depth.

So I require a function that, on every call (as happens in vmap), returns a different number sampled within some bound: depth ∈ [1, max_iters].

The function has to be jit-able (an implicit requirement of vmap) and has to produce an int, as that's what's fed into jnp.arange later. (Workarounds that directly make generate_random produce the jnp.arange(depth) array without converting to a static value might also be acceptable.)

(Honestly, I have no idea how others do this; it seems like a common enough need, especially when sampling during training.)

I've attached the error traceback generated by my "hacky solution attempt" if that helps...

---------------------------------------------------------------------------

ConcretizationTypeError                   Traceback (most recent call last)

<ipython-input-32-d6ff062f5054> in <cell line: 16>()
     14 a = jnp.zeros((300, 32)).astype(int)
     15 rndm_keys = jax.random.split(key, a.shape[0])
---> 16 jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape

    [... skipping hidden 3 frame]

4 frames

<ipython-input-32-d6ff062f5054> in react_forward(key, input)
      8   k = get_rand_num(key, 1, MAX_ITERS)
      9   # forward pass the model without tracking grads
---> 10   intermediate_array = jax.lax.stop_gradient(model(input, iters_to_do=k))
     11   # n-k passes, but track the gradient this time
     12   return model(input, MAX_ITERS - k, intermediate_array)

    [... skipping hidden 12 frame]

<ipython-input-22-4760d53eb89c> in __call__(self, input, iters_to_do, prev_thought)
     71       #interim_thought = self.main_block(interim_thought)
     72 
---> 73     interim_thought = self.iterate_for_steps(interim_thought, iters_to_do, x)
     74 
     75     return self.out_head(interim_thought)

    [... skipping hidden 12 frame]

<ipython-input-22-4760d53eb89c> in iterate_for_steps(self, interim_thought, iters_to_do, x)
     56         return self.main_block(interim_thought), None
     57 
---> 58     final_interim_thought, _ = jax.lax.scan(loop_body, interim_thought, jnp.arange(iters_to_do))
     59     return final_interim_thought
     60 

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in arange(start, stop, step, dtype)
   2286     util.check_arraylike("arange", start)
   2287     if stop is None and step is None:
-> 2288       start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'")
   2289     else:
   2290       start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'")

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in concrete_or_error(force, val, context)
   1379       return force(val.aval.val)
   1380     else:
-> 1381       raise ConcretizationTypeError(val, context)
   1382   else:
   1383     return force(val)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
It arose in the jnp.arange argument 'stop'
This BatchTracer with object id 140406974192336 was created on line:
  <ipython-input-32-d6ff062f5054>:8 (react_forward)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

Really appreciate you helping me out here. Cheers!

Upvotes: 1

Views: 1516

Answers (1)

jakevdp

Reputation: 86443

The random numbers you are generating are traced; in other words, they are known only at runtime (your mental model of "runtime" should be "operations running on the XLA device").

Python integers, on the other hand, are static; in other words, they must be defined at compile time (your mental model of "compile time" should be "operations that happen before runtime values are known").

With this framing, it's clear that you cannot convert a traced value to a static Python integer within jit, vmap or any other JAX transform, because static values must be known before the traced values are determined. Where this comes up in your minimal example is in the call to .item(), which attempts to cast a (traced) JAX array to a (static) Python scalar.

You can fix this by avoiding this cast. Here is a new version of your function that returns a zero-dimensional integer array, which is how JAX encodes an integer scalar at runtime:

@jax.jit
def t():
  return jnp.zeros((1,)).astype(int).reshape(())

That said, the fact that you are so concerned with creating an integer from an array makes me think that your model function requires its second argument to be static, and unfortunately the above won't help you in that case. For the reasons discussed above it is impossible to convert a traced value within a JAX transformation into a static value.
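As a minimal sketch of what a genuinely static depth would look like (with a toy body function standing in for the real model), the value would have to be sampled on the host, outside any transform, and passed via static_argnums, at the cost of a recompilation for each distinct value:

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def forward(x, depth):
  # depth is a concrete Python int here, so jnp.arange(depth) is allowed
  def body(carry, _):
    return carry * 0.9 + 1.0, None  # toy stand-in for one block of the model
  out, _ = jax.lax.scan(body, x, jnp.arange(depth))
  return out

x = jnp.ones((32,))
depth = int(np.random.randint(1, 6))  # sampled on the host, outside jit
y = forward(x, depth)                 # recompiles for each new depth value

A static depth like this is fixed for the whole jitted call, so it cannot vary per example under vmap.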


Edit: the issue you're running into is the fact that JAX arrays must have static shapes. In your code, you're generating random integers at runtime, and attempting to pass them to jnp.arange, which would result in a dynamically-shaped array. It is not possible to execute such code within transformations like jit or vmap.

Fixing this usually involves writing your code in a way that supports the dynamic computation size (for example, creating a padded array of a maximum size, or using jax.fori_loop in place of jax.scan).
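To make those suggestions concrete, here is a minimal sketch, with a toy block function standing in for the real model: one variant pads the scan to MAX_ITERS steps and masks out the extras, the other uses jax.lax.fori_loop with a traced upper bound.

import jax
import jax.numpy as jnp

MAX_ITERS = 5

def block(x):
  # toy stand-in for one iteration of the model's main block
  return x * 0.9 + 1.0

def forward(key, x):
  # traced random depth in [1, MAX_ITERS]
  depth = jax.random.randint(key, shape=(), minval=1, maxval=MAX_ITERS + 1)

  # Variant 1: padded/masked scan. Always run MAX_ITERS steps, but only
  # keep the update while the step index is below the sampled depth.
  # Shapes stay static and the loop remains reverse-mode differentiable.
  def loop_body(carry, i):
    new = block(carry)
    return jnp.where(i < depth, new, carry), None
  out_masked, _ = jax.lax.scan(loop_body, x, jnp.arange(MAX_ITERS))

  # Variant 2: fori_loop accepts a traced upper bound (it lowers to a
  # while_loop), so no jnp.arange(depth) is needed.
  out_dynamic = jax.lax.fori_loop(0, depth, lambda i, c: block(c), x)

  return out_masked, out_dynamic

keys = jax.random.split(jax.random.PRNGKey(0), 8)
xs = jnp.ones((8, 32))
masked, dynamic = jax.vmap(forward)(keys, xs)
print(masked.shape, dynamic.shape)  # (8, 32) (8, 32)

The masked-scan variant always does MAX_ITERS steps of work regardless of the sampled depth; that extra compute is usually the price of keeping every shape static. The fori_loop variant lowers to a while_loop when its bounds are traced, so it cannot be reverse-mode differentiated, which matters if this runs inside a training step.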

Upvotes: 2
