bsaoptima

JAX traces a static argument

I am trying to use some JAX code in a Pallas kernel, but for some reason it no longer works.

import functools
import jax
import jax.numpy as jnp
from jax import Array
from jax.experimental import pallas as pl

key_diags, key_other = jax.random.split(jax.random.PRNGKey(52))
other = jax.random.normal(key_other, (10, 10))
diags = jax.random.normal(key_diags, (3, 10))
offsets = (-2, 1, 2)

def dia_matmul_kernel(diags_ref, offsets, other_ref, o_ref):
    diags, other = diags_ref[...], other_ref[...]
    N = other.shape[0]
    out = jnp.zeros((N, N))
    print(offsets)
    for offset, diag in zip(offsets, diags):
        start = jax.lax.max(0, offset)
        end = min(N, N + offset)
        top = max(0, -offset)
        bottom = top + end - start
        out = out.at[top:bottom, :].add(
            diag[start:end, None] * other[start:end, :]
        )
    o_ref[...] = out

@functools.partial(jax.jit, static_argnums=(1,))
def dia_matmul(diags: Array, offsets: tuple[int, ...], other: Array) -> Array:
    return pl.pallas_call(
        dia_matmul_kernel,
        out_shape=jax.ShapeDtypeStruct(other.shape, other.dtype),
    )(diags, offsets, other)

dia_matmul(diags, offsets, other)

I understand that it is not best practice to print inside a jitted JAX function, but when I print offsets, which should be static because of static_argnums=(1,), I get:

(Traced<MemRef<None>{int32[]}>with<DynamicJaxprTrace(level=3/0)>, Traced<MemRef<None>{int32[]}>with<DynamicJaxprTrace(level=3/0)>, Traced<MemRef<None>{int32[]}>with<DynamicJaxprTrace(level=3/0)>)

I don't understand why this happens. I'm new to JAX and Pallas, so I'm not yet fully comfortable with the concept of tracing. Also, the last operation in the for loop (the update to out) is not working, so any ideas on that are welcome too :D

Many thanks!

Answers (1)

jakevdp

Arguments passed to pallas_call are always traced, regardless of whether they were static in the enclosing function. This is true any time you pass arguments to a JAX primitive or to a transformed function: all inputs are traced unless they are explicitly marked as static in the function you are calling. That is why each element of your offsets tuple shows up inside the kernel as a traced scalar ref (Traced<MemRef<None>{int32[]}>).
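A minimal sketch of the same effect with plain jax.jit, outside Pallas. The names outer, inner, inner_static and the observed dict are made up for illustration; the check relies on tracers being instances of jax.Array while a plain Python int is not:

```python
import functools

import jax
import jax.numpy as jnp

# Records, per function, whether n was a JAX value (array or tracer) at trace time.
observed = {}


def inner(x, n):
    # jax.Array covers tracers too; a plain Python int is not a jax.Array.
    observed["inner"] = isinstance(n, jax.Array)
    return x * n


@functools.partial(jax.jit, static_argnums=1)
def inner_static(x, n):
    observed["inner_static"] = isinstance(n, jax.Array)
    return x * n


@functools.partial(jax.jit, static_argnums=1)
def outer(x, n):
    # n is static here, so it is still a Python int...
    observed["outer"] = isinstance(n, jax.Array)
    # ...but passing it to a jitted callee that does not mark it static traces it.
    a = jax.jit(inner)(x, n)
    # Marking it static in the callee keeps it a Python value.
    b = inner_static(x, n)
    return a + b


result = outer(jnp.arange(3.0), 2)
print(observed)  # {'outer': False, 'inner': True, 'inner_static': False}
```

Being static in outer does not carry over to the callee: whether n stays a Python value is decided by the function that receives it.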

pallas_call doesn't currently have any way of marking static arguments; it traces every argument you pass to the function it returns. If you want some arguments to be static, you should be able to do this by closing over them in the kernel function that you pass to pallas_call:

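@functools.partial(jax.jit, static_argnums=(1,))
def dia_matmul(diags: Array, offsets: tuple[int, ...], other: Array) -> Array:
    return pl.pallas_call(
        # The lambda closes over offsets, so the kernel sees the original Python tuple;
        # only diags and other are passed to pallas_call (and therefore traced).
        lambda diags_ref, other_ref, o_ref: dia_matmul_kernel(diags_ref, offsets, other_ref, o_ref),
        out_shape=jax.ShapeDtypeStruct(other.shape, other.dtype),
    )(diags, other)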
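On the second part of the question (the update to out): even with offsets closed over, the kernel above computes start with jax.lax.max, which is staged out during tracing and returns a tracer, so diag[start:end] and out.at[top:bottom] no longer have static slice bounds. Below is a sketch that uses the built-in max and min instead and binds offsets with functools.partial as a keyword-only argument, an alternative to the lambda. Assumptions: interpret=True is set only so it runs on CPU, dia_to_dense is a hypothetical helper for checking against a dense matmul, and the layout (diags[k, j] is the entry in column j of diagonal offsets[k]) is inferred from the kernel's indexing. Compiled GPU or TPU lowering may impose extra constraints on block shapes and on the .at[].add update that this sketch does not address.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def dia_matmul_kernel(diags_ref, other_ref, o_ref, *, offsets):
    # offsets is bound with functools.partial, so it is a plain tuple of ints here.
    diags, other = diags_ref[...], other_ref[...]
    n = other.shape[0]
    out = jnp.zeros(other.shape, other.dtype)
    for k, offset in enumerate(offsets):
        # Built-in max/min keep every slice bound a Python int (static).
        start = max(0, offset)
        end = min(n, n + offset)
        top = max(0, -offset)
        bottom = top + end - start
        # Row (j - offset) of the result gets diags[k, j] * other[j] for j in [start, end).
        out = out.at[top:bottom, :].add(diags[k, start:end, None] * other[start:end, :])
    o_ref[...] = out


@functools.partial(jax.jit, static_argnums=(1,))
def dia_matmul(diags, offsets, other):
    return pl.pallas_call(
        functools.partial(dia_matmul_kernel, offsets=offsets),
        out_shape=jax.ShapeDtypeStruct(other.shape, other.dtype),
        interpret=True,  # interpret mode, so the sketch also runs on CPU
    )(diags, other)


def dia_to_dense(diags, offsets, n):
    # Hypothetical checking helper: diags[k, j] goes to dense[j - offsets[k], j].
    dense = np.zeros((n, n), dtype=np.float32)
    for k, offset in enumerate(offsets):
        for j in range(max(0, offset), min(n, n + offset)):
            dense[j - offset, j] = diags[k, j]
    return dense


key_diags, key_other = jax.random.split(jax.random.PRNGKey(52))
other = jax.random.normal(key_other, (10, 10))
diags = jax.random.normal(key_diags, (3, 10))
offsets = (-2, 1, 2)

result = dia_matmul(diags, offsets, other)
expected = dia_to_dense(np.asarray(diags), offsets, 10) @ np.asarray(other)
print(np.allclose(np.asarray(result), expected, atol=1e-5))
```

Compared with the lambda, a keyword-only offsets parameter keeps the kernel's positional signature identical to the refs pallas_call passes in; both approaches keep offsets out of the traced inputs.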
