Reputation: 2215
I am using JAX, and I want to perform an operation like
    @jax.jit
    def fun(x, index):
        x[:index] = other_fun(x[:index])
        return x
This cannot be performed under jit. Is there a way of doing this with jax.ops or jax.lax?
I thought of using jax.ops.index_update(x, idx, y), but I cannot find a way of computing y without running into the same problem again.
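For reference, jax.ops.index_update has since been deprecated and removed from JAX; the same functional update is now written x.at[idx].set(y). A minimal sketch outside jit, assuming a hypothetical other_fun(x) = x + 1 and a concrete Python int for index (so the slice shape is known). Note this does not by itself solve the traced-index problem.

```python
import jax.numpy as jnp

def other_fun(x):
    # Hypothetical stand-in for the question's other_fun.
    return x + 1

x = jnp.arange(5)
index = 3  # concrete Python int, not a traced value

# Functional equivalent of x[:index] = other_fun(x[:index]):
# returns a NEW array and leaves x untouched.
y = x.at[:index].set(other_fun(x[:index]))
print(y)
print(x)
```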
Upvotes: 7
Views: 4377
Reputation: 86513
The answer by @rvinas using dynamic_update_slice works well if your index is static, but you can also accomplish this with a dynamic index using jnp.where. For example:
    import jax
    import jax.numpy as jnp

    def other_fun(x):
        return x + 1

    @jax.jit
    def fun(x, index):
        # True for positions before index; index may be a traced value
        mask = jnp.arange(x.shape[0]) < index
        return jnp.where(mask, other_fun(x), x)

    x = jnp.arange(5)
    print(fun(x, 3))
    # [1 2 3 3 4]
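One caveat with the jnp.where approach: other_fun is evaluated on the whole array, so the result matches the slice-based version only when other_fun is elementwise. A minimal sketch, assuming a hypothetical non-elementwise other_fun that subtracts the mean of its input: the naive version mixes the tail into the mean, while the masked version computes the reduction over the prefix only.

```python
import jax
import jax.numpy as jnp

def center(x):
    # Hypothetical non-elementwise other_fun: subtract the mean of the input.
    return x - x.mean()

@jax.jit
def naive(x, index):
    # center() sees the WHOLE array, so its mean includes the tail entries.
    mask = jnp.arange(x.shape[0]) < index
    return jnp.where(mask, center(x), x)

@jax.jit
def masked(x, index):
    # Use the mask inside the reduction so only the prefix contributes.
    mask = jnp.arange(x.shape[0]) < index
    prefix_mean = jnp.sum(jnp.where(mask, x, 0)) / index
    return jnp.where(mask, x - prefix_mean, x)

x = jnp.arange(5.0)
print(naive(x, 3))   # differs from the slice-based result
print(masked(x, 3))  # matches center(x[:3]) on the prefix
```

The masked form keeps every shape static, so it stays compatible with a traced index.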
Upvotes: 7
Reputation: 11895
It seems there are two issues in your implementation. First, slicing with a traced index produces an array whose shape depends on a runtime value, and dynamically shaped arrays are not allowed in jitted code. Second, unlike NumPy arrays, JAX arrays are immutable (i.e. the contents of an array cannot be changed in place).
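Both issues can be reproduced directly. A minimal sketch (the helper name take_prefix is hypothetical); the exact exception types may vary between JAX versions, so the sketch catches and reports them rather than relying on one message.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)

# Issue 2: JAX arrays are immutable, so item assignment fails even without jit.
try:
    x[:3] = 0
    assignment_failed = False
except TypeError as err:
    assignment_failed = True
    print("assignment:", type(err).__name__)

# Issue 1: under jit, `index` is a tracer, so x[:index] would have a
# value-dependent shape; JAX rejects the slice while tracing.
@jax.jit
def take_prefix(x, index):
    return x[:index]

try:
    take_prefix(x, 3)
    slicing_failed = False
except (IndexError, TypeError) as err:
    slicing_failed = True
    print("slicing:", type(err).__name__)
```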
You can overcome both problems by combining static_argnums and jax.lax.dynamic_update_slice. Here is an example:
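    import functools
    import jax
    import jax.numpy as jnp

    def other_fun(x):
        return x + 1

    # jax.partial (an alias of functools.partial) is gone from recent JAX releases
    @functools.partial(jax.jit, static_argnums=(1,))
    def fun(x, index):
        update = other_fun(x[:index])  # index is static, so this shape is known
        return jax.lax.dynamic_update_slice(x, update, (0,))

    x = jnp.arange(5)
    print(fun(x, 3))  # prints [1 2 3 3 4]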
Essentially, the example above uses static_argnums to mark index as a compile-time constant, so the function is recompiled for each distinct index value, and jax.lax.dynamic_update_slice returns a new array equal to x with its first len(update) entries replaced by update (x itself is not modified).
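If only the length of the slice needs to be static while its position varies at runtime, jax.lax.dynamic_slice and jax.lax.dynamic_update_slice both accept traced start indices. A minimal sketch, assuming a hypothetical helper update_window whose window size is static and whose start is traced; note that dynamic_slice clamps start so the window stays in bounds.

```python
import functools
import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@functools.partial(jax.jit, static_argnums=(2,))
def update_window(x, start, size):
    # size fixes the window shape (static); start may be a traced value.
    window = jax.lax.dynamic_slice(x, (start,), (size,))
    return jax.lax.dynamic_update_slice(x, other_fun(window), (start,))

x = jnp.arange(5)
print(update_window(x, 1, 2))  # [0 2 3 3 4]
```

Changing start reuses the same compiled function; only a new size triggers recompilation.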
Upvotes: 3