Federico Taschin

Reputation: 2215

Apply function only on slice of array under jit

I am using JAX, and I want to perform an operation like

@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x

This cannot be performed under jit. Is there a way of doing this with jax.ops or jax.lax? I thought of using jax.ops.index_update(x, idx, y), but I cannot find a way of computing y without running into the same problem again.

Upvotes: 7

Views: 4377

Answers (2)

jakevdp

Reputation: 86513

The other answer by @rvinas using dynamic_update_slice works well if your index is static, but you can also handle a dynamic (traced) index by using jnp.where. For example:

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
  mask = jnp.arange(x.shape[0]) < index
  return jnp.where(mask, other_fun(x), x)

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
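
The same masking idea can be extended to arrays with more than one dimension. The following is only a sketch: fun_rows is a hypothetical name, other_fun is the same stand-in as above, and the mask is reshaped so that it broadcasts over the trailing axes.

```python
import jax
import jax.numpy as jnp

def other_fun(x):
    # Stand-in elementwise function, assumed for illustration.
    return x + 1

@jax.jit
def fun_rows(x, index):
    # Boolean mask over the leading axis; index may be a traced value.
    mask = jnp.arange(x.shape[0]) < index
    # Reshape to (n, 1, ..., 1) so the mask broadcasts over trailing axes.
    mask = mask.reshape((-1,) + (1,) * (x.ndim - 1))
    # other_fun is evaluated on the whole array; the mask picks which results to keep.
    return jnp.where(mask, other_fun(x), x)

x = jnp.arange(6).reshape(3, 2)
result = fun_rows(x, 2)
print(result)
```

Note that other_fun runs on every element, so this approach fits elementwise functions. If other_fun depends on the slice as a whole (a sum or a normalization, for instance), the input would likely need to be masked as well.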

Upvotes: 7

rvinas

Reputation: 11895

It seems there are two issues in your implementation. First, slicing with a traced index, as in x[:index], produces an array whose shape depends on a runtime value, and dynamically shaped arrays are not allowed in jitted code. Second, unlike NumPy arrays, JAX arrays are immutable (i.e. the contents of the array cannot be changed in place).
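
To illustrate the immutability point outside of jit, here is a small sketch. It assumes a recent JAX version where arrays expose the .at[...] property; the variable names are made up for the example.

```python
import jax.numpy as jnp

x = jnp.arange(5)

try:
    x[:3] = 0  # in-place assignment is rejected by JAX arrays
    mutated = True
except TypeError:
    mutated = False

# The functional alternative returns a new array and leaves x untouched.
y = x.at[:3].set(0)

print(mutated)
print(x)
print(y)
```

The slice bounds here are plain Python integers, so the shape is known up front. Under jit the same .at[...] update would still need a static index.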

You can overcome the two problems by combining static_argnums and jax.lax.dynamic_update_slice. Here is an example:

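import functools

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

# index is static, so x[:index] has a shape known at compile time
@functools.partial(jax.jit, static_argnums=(1,))
def fun(x, index):
    update = other_fun(x[:index])
    return jax.lax.dynamic_update_slice(x, update, (0,))

x = jnp.arange(5)
print(fun(x, 3))  # prints [1 2 3 3 4]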

Essentially, static_argnums tells jit to treat index as a compile-time constant, so the function is recompiled for each distinct index value, and jax.lax.dynamic_update_slice returns a copy of x in which the first len(update) elements are replaced by update.
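
One way to observe the recompilation described above is to count how often the Python body of the function is traced. This is a rough sketch: trace_count is a hypothetical counter added only for illustration, and it relies on the fact that Python side effects inside a jitted function run at trace time, not on every call.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, for illustration only

@functools.partial(jax.jit, static_argnums=(1,))
def fun(x, index):
    global trace_count
    trace_count += 1  # Python side effect: executes only while tracing
    update = x[:index] + 1
    return jax.lax.dynamic_update_slice(x, update, (0,))

x = jnp.arange(5)
fun(x, 3)             # first call: traces and compiles for index=3
fun(x, 3)             # same static value: reuses the cached compilation
out = fun(x, 2)       # new static value: traces and compiles again
print(trace_count)
print(out)
```

If index takes many distinct values, each one pays the compilation cost once, which is the trade-off to weigh against the jnp.where approach.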

Upvotes: 3
