oneloop

Reputation: 197

How to JIT code involving masked arrays without NonConcreteBooleanIndexError?

Follow-up to Count onto 2D JAX coordinates of another 2D array.

I have a function (schematically):

# works
def predict(model, x1, x2, x2_mask):
    y = somefunc(x1)
    z = y.at[x2[x2_mask][:, 0], x2[x2_mask][:, 1]].get()
    w = something.at[jnp.where(x2_mask)].set(z)
    return w

The intuition of what I'm trying to do is that x1 is some encoding of a game board (grid) and somefunc are some convolutional operations. Then I want to read off from the resulting convolutions the values (vectors) at coordinates x2 masked by x2_mask.

x2 is (N, 2) and x2_mask is (N,). z is (something,2) where something depends on the content of the mask. w has the same content as z except stretched to known size (N, 2) and with a placeholder value in the remaining places.
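For concreteness, here's a tiny sketch (with made-up values, N = 4) of why z ends up with a data-dependent size:

```python
import jax.numpy as jnp

x2 = jnp.array([[0, 1], [2, 3], [1, 1], [3, 0]])  # (N, 2) coordinates
x2_mask = jnp.array([True, False, True, False])   # (N,) boolean mask

# The result's size depends on the *values* in the mask, not just its shape:
print(x2[x2_mask].shape)  # (2, 2)
```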

It works.

I try and jit it:

# fails NonConcreteBooleanIndexError
@nnx.jit
def predict(model, x1, x2, x2_mask):
    y = somefunc(x1)
    z = y.at[x2[x2_mask][:, 0], x2[x2_mask][:, 1]].get()
    w = something.at[jnp.where(x2_mask)].set(z)
    return w

First off, notice I'm using nnx.jit because that's what I see here. I guess if the function has state you have to use nnx.jit instead of jax.jit, although I don't understand the details; I suspect the former is just syntactic sugar for the latter. Also note that I'm putting this on the predict step, not on the train step as in the example.

But then I get this error. From what I read, jit operations can't contain variable size arrays, so I'm guessing this is related to x2[x2_mask]. And indeed, if I remove that part jit works:

# works
@nnx.jit
def predict(model, x1, x2, x2_mask):
    y = somefunc(x1)
    return y

The inputs x1, x2, x2_mask all have known fixed sizes. The size of y and the size of the output w can be inferred from the sizes of the inputs. The size of z depends on how many True you have in the mask.

What confuses me here is that I would've expected that compiled code doesn't really have to materialize x2[x2_mask], so why does it matter that it has variable size? It's just a for loop that skips some elements; there's really no uncertainty about how much memory will be used. Same thing for z: it's only used to determine which values to set into a known-size array, so why does it matter that its size isn't known?

Am I making some mistake and, if not, is there still an alternative way to achieve the same?

I don't strictly need to jit my code; at this stage I'd just like it to work. But I came across this issue when looking into auto-batching:

batched_predict = jax.vmap(predict, in_axes=(None, 0, 0, 0))
batched_predict(model, x1s, x2s, x2_masks)

So this is causing at least two obstacles.

I believe the answer might be in Apply function only on slice of array under jit, I'm trying to work out the details.

I was wondering whether the w= line would also make jit fail, in addition to the z= line, and it looks like it does. If I remove the mask from the z= line, just for the purposes of testing:

@nnx.jit
def predict(model, x1, x2, x2_mask):
    y = somefunc(x1)
    z = y.at[x2[:, 0], x2[:, 1]].get()
    w = something.at[jnp.where(x2_mask)].set(z)
    return w

This also fails, now with this error. something.at[jnp.where(...)] seems like it should be extremely commonplace.

Upvotes: 1

Views: 77

Answers (2)

jakevdp

Reputation: 86443

In cases of boolean mask indexing, you can often re-express it in a JIT-compatible way using the three-term jax.numpy.where function.

In this case, you should replace this:

w = something.at[jnp.where(x2_mask)].set(z)

with this:

w = jnp.where(x2_mask, z, something)
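As a self-contained sketch of the same idea (the names, shapes, and values here are made up for illustration): instead of indexing with the boolean mask, compute a fixed-size update for every row and use the three-term where to turn the masked-off rows into no-ops.

```python
import jax
import jax.numpy as jnp

@jax.jit
def count_masked(x, coords, coords_mask):
    # Boolean-mask indexing (coords[coords_mask]) has a value-dependent size
    # and fails under jit. Instead, build a fixed-size (N,) update vector and
    # zero out the masked-off rows with the three-term where.
    updates = jnp.where(coords_mask, 1.0, 0.0)
    return x.at[coords[:, 0], coords[:, 1]].add(updates)

x = jnp.zeros((5, 5))
coords = jnp.array([[1, 2], [2, 3], [1, 2], [1, 2]])
coords_mask = jnp.array([True, True, False, True])
print(count_masked(x, coords, coords_mask))  # x[1, 2] == 2.0, x[2, 3] == 1.0
```

Duplicate coordinates accumulate under .add, so the masked-off row contributes 0 while the two unmasked [1, 2] rows each contribute 1.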

Upvotes: 0

oneloop

Reputation: 197

I found a horrible hack that allows me to make jitted functions which operate on masked arrays. I'm hoping @jakevdp is gonna show up and show me how it's done.

OK, so the idea is that even though x2[x2_mask] has a size that depends on the values of x2_mask, once the operation x.at[x2[x2_mask][:, 0], x2[x2_mask][:, 1]].add(1) is compiled, it's just a for loop that skips based on the values of x2_mask. Even though the intermediate variables' sizes are not known at compile time, the memory layout needed to carry out the operation is known at compile time. But if jax needs intermediate variables with statically known sizes, then we can fill the entries of the masked array we don't care about with coordinates we don't care about. And since in principle we care about all coordinates of x1, we first enlarge it.

Let's create a minimal example that shows the problem:

x = jnp.zeros((5,5))
coords = jnp.array([
    [1,2],
    [2,3],
    [1,2],
    [1,2],
])
coords_mask = jnp.array([True, True, False, True])

def testfunc(x, coords, coords_mask):
    coords_masked = coords[coords_mask]
    return x.at[coords_masked[:, 0], coords_masked[:, 1]].add(1)

testfunc(x, coords, coords_mask)

This works and outputs:

[[0., 0., 0., 0., 0.],
 [0., 0., 2., 0., 0.],
 [0., 0., 0., 1., 0.],
 [0., 0., 0., 0., 0.],
 [0., 0., 0., 0., 0.]]

Note that one of the [1,2] coordinates has been masked out and the other two were counted twice.

But this doesn't work:

@jax.jit
def testfunc(x, coords, coords_mask):
    coords_masked = coords[coords_mask]
    return x.at[coords_masked[:, 0], coords_masked[:, 1]].add(1)

testfunc(x, coords, coords_mask)  # NonConcreteBooleanIndexError

So here's a horrible hack to work around it:

@jax.jit
def testfunc(x, coords, coords_mask):
    len_0, len_1 = x.shape

    # enlarge x by 1 in axis=1
    x = jnp.concatenate([x, jnp.zeros((len_0, 1))], axis=1)

    # prepare mask coordinates so that the False points to a position in the enlarged x array
    default = jnp.full(coords.shape, fill_value=jnp.array([0, len_1]))
    mask_repeated = jnp.repeat(coords_mask.reshape((coords.shape[0],1)), coords.shape[1], axis=1)
    coords_masked = jnp.where(mask_repeated, coords, default)

    # scatter coords onto enlarged x
    x = x.at[coords_masked[:, 0], coords_masked[:, 1]].add(1)
    
    # take a slice of x
    x = x[:, :len_1]
    return x

testfunc(x, coords, coords_mask)  # works

Upvotes: 1
