Reputation: 1255
I have a setup where I need to generate some random number that is consumed by vmap and then, later on, by lax.scan:
def generate_random(key: Array, upper_bound: int, lower_bound: int) -> int:
    ...
    return num.astype(int)

def forward(key: Array, input: Array) -> Array:
    k = generate_random(key, 1, 5)
    computation = model(.., k, ..)
    ...

# Computing the forward pass
output = jax.vmap(forward, in_axes=.....
But attempting to convert num from a jax.Array to an int32 causes a ConcretizationTypeError.
This can be reproduced through this minimal example:
@jax.jit
def t():
    return jnp.zeros((1,)).item().astype(int)

o = t()
JIT requires that all manipulations stay within JAX types, but vmap uses JIT implicitly, and I would prefer to keep vmap for performance reasons.
This was my hacky attempt:
@partial(jax.jit, static_argnums=(1, 2))
def get_rand_num(key: Array, lower_bound: int, upper_bound: int) -> int:
    key, subkey = jax.random.split(key)
    random_number = jax.random.randint(subkey, shape=(), minval=lower_bound, maxval=upper_bound)
    return random_number.astype(int)
def react_forward(key: Array, input: Array) -> Array:
    k = get_rand_num(key, 1, MAX_ITERS)
    # forward pass the model without tracking grads
    intermediate_array = jax.lax.stop_gradient(model(input, k))  # THIS LINE ERRORS OUT
    ...
    return ...

a = jnp.zeros((300, 32)).astype(int)
rndm_keys = jax.random.split(key, a.shape[0])
jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape
This creates a.shape[0] (the batch size) subkeys, one per batch element in vmap, so each call samples a fresh random number. But it doesn't work, because of k being cast from jax.Array -> int.
But making this change:

- k = get_rand_num(key, 1, MAX_ITERS)
+ k = 5  # any hardcoded int

works perfectly. Clearly, the sampling is causing the problem here...
To not make this into an X-Y problem, I'll define precisely what I want:
I'm implementing a version of stochastic depth; basically, my model's forward pass can accept a depth: int at runtime, which is the length of a scan run internally (specifically, the xs = jnp.arange(depth) for the scan).
I want my architecture to flexibly adapt to different depths. Therefore, at training time, I need a way to produce pseudorandom numbers to serve as the depth.
So I require a function that returns a different number on every call (as happens under vmap), sampled within some bound: depth ∈ [1, max_iters]. The function has to be jit-able (an implicit requirement of vmap) and has to produce an int, as that's what is fed into jnp.arange later. (Workarounds that get generate_random to directly produce the Array jnp.arange(depth), without ever converting to a static value, might also be acceptable.)
(Honestly, I have no idea how others do this; it seems like a common enough need, especially when dealing with sampling at train time.)
I've attached the error traceback generated by my "hacky solution attempt" if that helps...
---------------------------------------------------------------------------
ConcretizationTypeError Traceback (most recent call last)
<ipython-input-32-d6ff062f5054> in <cell line: 16>()
14 a = jnp.zeros((300, 32)).astype(int)
15 rndm_keys = jax.random.split(key, a.shape[0])
---> 16 jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape
[... skipping hidden 3 frame]
<ipython-input-32-d6ff062f5054> in react_forward(key, input)
8 k = get_rand_num(key, 1, MAX_ITERS)
9 # forward pass the model without tracking grads
---> 10 intermediate_array = jax.lax.stop_gradient(model(input, iters_to_do=k))
11 # n-k passes, but track the gradient this time
12 return model(input, MAX_ITERS - k, intermediate_array)
[... skipping hidden 12 frame]
<ipython-input-22-4760d53eb89c> in __call__(self, input, iters_to_do, prev_thought)
71 #interim_thought = self.main_block(interim_thought)
72
---> 73 interim_thought = self.iterate_for_steps(interim_thought, iters_to_do, x)
74
75 return self.out_head(interim_thought)
[... skipping hidden 12 frame]
<ipython-input-22-4760d53eb89c> in iterate_for_steps(self, interim_thought, iters_to_do, x)
56 return self.main_block(interim_thought), None
57
---> 58 final_interim_thought, _ = jax.lax.scan(loop_body, interim_thought, jnp.arange(iters_to_do))
59 return final_interim_thought
60
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in arange(start, stop, step, dtype)
2286 util.check_arraylike("arange", start)
2287 if stop is None and step is None:
-> 2288 start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'")
2289 else:
2290 start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'")
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in concrete_or_error(force, val, context)
1379 return force(val.aval.val)
1380 else:
-> 1381 raise ConcretizationTypeError(val, context)
1382 else:
1383 return force(val)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
It arose in the jnp.arange argument 'stop'
This BatchTracer with object id 140406974192336 was created on line:
<ipython-input-32-d6ff062f5054>:8 (react_forward)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
Really appreciate you helping me out here. Cheers!
Upvotes: 1
Views: 1516
Reputation: 86443
The random numbers you are generating are traced; in other words, they are known only at runtime (your mental model of "runtime" should be "operations running on the XLA device").
Python integers, on the other hand, are static; in other words, they must be defined at compile time (your mental model of "compile time" should be "operations that happen before runtime values are known").
With this framing, it's clear that you cannot convert a traced value to a static Python integer within jit, vmap, or any other JAX transform, because static values must be known before the traced values are determined. Where this comes up in your minimal example is in the call to .item(), which attempts to cast a (traced) JAX array to a (static) Python scalar.
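To make the traced/static distinction concrete, here is a minimal sketch (my own illustration, not code from the question; the function names are hypothetical). A value passed normally into a jitted function is a tracer and cannot be cast to a Python scalar, while an argument marked with static_argnums is a concrete Python int at trace time and can parameterize shapes:

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def bad(x):
    # x is a tracer here: its value is unknown at trace time,
    # so casting it to a Python scalar raises a concretization error.
    return x.item()

@partial(jax.jit, static_argnums=(1,))
def good(x, n):
    # n is declared static, so it is a concrete Python int at trace
    # time and may set array shapes, e.g. via jnp.arange(n).
    return x + jnp.arange(n).sum()

print(good(jnp.float32(0.0), 4))  # 6.0
# bad(jnp.zeros(())) raises an error as soon as tracing reaches .item()
```

Note the trade-off: every new value of a static argument triggers a recompilation, which is exactly why a per-example random depth cannot be made static.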
You can fix this by avoiding this cast. Here is a new version of your function that returns a zero-dimensional integer array, which is how JAX encodes an integer scalar at runtime:
@jax.jit
def t():
    return jnp.zeros((1,)).astype(int).reshape(())
That said, the fact that you are so concerned with creating an integer from an array makes me think that your model function requires its second argument to be static, and unfortunately the above won't help you in that case. For the reasons discussed above, it is impossible to convert a traced value into a static value within a JAX transformation.
Edit: the issue you're running into is the fact that JAX arrays must have static shapes. In your code, you're generating random integers at runtime and attempting to pass them to jnp.arange, which would result in a dynamically-shaped array. It is not possible to execute such code within transformations like jit or vmap.
Fixing this usually involves writing your code in a way that supports the dynamic computation size (for example, creating a padded array of a maximum size, or using jax.lax.fori_loop in place of jax.lax.scan).
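A sketch of the padded approach for the stochastic-depth use case (my own illustration; block, MAX_ITERS, and the toy update are placeholders, not the asker's model): always scan for MAX_ITERS steps so every shape stays static, and use the traced depth only inside a mask:

```python
import jax
import jax.numpy as jnp

MAX_ITERS = 5  # assumed maximum depth

def block(x):
    # stand-in for one iteration of the model's recurrent block
    return x * 0.9 + 1.0

def forward(key, x):
    # traced depth in [1, MAX_ITERS]; never converted to a Python int
    depth = jax.random.randint(key, shape=(), minval=1, maxval=MAX_ITERS + 1)

    def body(carry, i):
        new = block(carry)
        # apply the block only while i < depth; afterwards pass carry through
        return jnp.where(i < depth, new, carry), None

    # scan length is the static MAX_ITERS, not the traced depth
    out, _ = jax.lax.scan(body, x, jnp.arange(MAX_ITERS))
    return out

keys = jax.random.split(jax.random.PRNGKey(0), 8)
xs = jnp.zeros((8, 3))
out = jax.vmap(forward)(keys, xs)  # shape (8, 3)
```

Alternatively, jax.lax.fori_loop(0, depth, lambda i, c: block(c), x) accepts a traced upper bound (it lowers to a while_loop), at the cost of not being reverse-mode differentiable; the masked scan above keeps gradients available.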
Upvotes: 2