Carpetfizz

Reputation: 9149

JIT a least squares loss function in Jax

I have a simple loss function that looks like this:

        def loss(r, x, y):
            resid = f(r, x) - y
            return jnp.mean(jnp.square(resid))

I would like to optimize over the parameter r and use some static parameters x and y to compute the residual. All parameters in question are DeviceArrays.

In order to JIT this, I tried the following:

        @partial(jax.jit, static_argnums=(1, 2))
        def loss(r, x, y):
            resid = f(r, x) - y
            return jnp.mean(jnp.square(resid))

but I get this error:

jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'numpy.ndarray'> for function loss is non-hashable.

I understand from #6233 that this is by design, but I was wondering what the workaround is here, as this seems like a very common use case where you have some fixed (input, output) training data pairs and a free variable.

Thanks for any tips!

EDIT: this is the error I get when I just try to use jax.jit:

jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function loss at /path/to/my/script:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'r'.
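(For context: this error typically means f contains Python control flow, such as an `if`, that calls `bool()` on a traced value. A minimal sketch with a hypothetical `f_bad` that reproduces it, and a `jnp.where`-based fix; neither function is from the original code:)

```python
import jax
import jax.numpy as jnp

# Hypothetical function that triggers the error under jit:
# Python's `if` calls bool() on the traced value `r`, which is abstract.
@jax.jit
def f_bad(r):
    if r > 0:  # bool(Traced<ShapedArray(bool[])>) -> ConcretizationTypeError
        return r
    return -r

# Fix: jnp.where traces both branches instead of forcing a concrete bool.
@jax.jit
def f_ok(r):
    return jnp.where(r > 0, r, -r)

print(f_ok(jnp.float32(-3.0)))  # 3.0
```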

Upvotes: 3

Views: 1043

Answers (1)

jakevdp

Reputation: 86320

It sounds like you're thinking of static arguments as "values that don't vary between computations". In JAX's JIT, static arguments can better be thought of as "hashable compile-time constants". In your case, you don't have hashable compile-time constants; you have arrays, so you can just JIT-compile with no static args:

@jit
def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))
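Since the question doesn't show f, here's a runnable sketch assuming a hypothetical linear model f(r, x) = r * x, showing that the jitted loss works with ordinary (non-static) array arguments and composes with jax.grad for optimizing over r:

```python
import jax
import jax.numpy as jnp

# Hypothetical model, assumed linear for illustration.
def f(r, x):
    return r * x

@jax.jit
def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))

x = jnp.arange(5.0)
y = 2.0 * x  # true parameter is r = 2
r = 0.0

# x and y are traced like any other argument; no static_argnums needed.
grad_loss = jax.jit(jax.grad(loss))
for _ in range(100):
    r = r - 0.1 * grad_loss(r, x, y)
```

After these steps r converges to 2.0, and repeated calls with new x, y arrays of the same shape reuse the compiled function.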

If you really want the JAX machinery to know that your arrays are constant, you can do so by passing them via a closure or a partial; for example:

from functools import partial

def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))
loss = jit(partial(loss, x=x, y=y))

However, for the type of computation you are doing, where the constants are arrays operated on by JAX array functions, these two approaches lead to basically identical lowered XLA code, so you may as well use the simpler one.
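If you want to check this for yourself, you can compare the traces of the two variants with jax.make_jaxpr (or the lowered HLO via jax.jit(...).lower(...).as_text()). A sketch, again assuming a hypothetical linear f:

```python
import jax
import jax.numpy as jnp
from functools import partial

# Hypothetical model, assumed linear for illustration.
def f(r, x):
    return r * x

def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))

x = jnp.arange(4.0)
y = 2.0 * x
r = jnp.float32(0.5)

# Variant 1: x and y are traced arguments.
jaxpr_args = jax.make_jaxpr(loss)(r, x, y)
# Variant 2: x and y are baked in as constants via partial.
jaxpr_closed = jax.make_jaxpr(partial(loss, x=x, y=y))(r)

print(jaxpr_args)
print(jaxpr_closed)
```

Both jaxprs contain the same sequence of primitives (mul, sub, reduce over the squared residual); the only difference is whether x and y appear as inputs or as embedded constants.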

Upvotes: 5
