I am trying to use some JAX code in a Pallas kernel, but for some reason it no longer works.
import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array
from jax.experimental import sparse

key = jax.random.PRNGKey(52)
other = jax.random.normal(key, (10, 10))
diags = jax.random.normal(key, (3, 10))
offsets = (-2, 1, 2)

def dia_matmul_kernel(diags_ref, offsets, other_ref, o_ref):
    diags, other = diags_ref[...], other_ref[...]
    N = other.shape[0]
    out = jnp.zeros((N, N))
    print(offsets)
    for offset, diag in zip(offsets, diags):
        start = jax.lax.max(0, offset)
        end = min(N, N + offset)
        top = max(0, -offset)
        bottom = top + end - start
        out = out.at[top:bottom, :].add(
            diag[start:end, None] * other[start:end, :]
        )
    o_ref[...] = out

@functools.partial(jax.jit, static_argnums=(1, ))
def dia_matmul(diags: Array, offsets: tuple[int], other: Array) -> Array:
    return pl.pallas_call(
        dia_matmul_kernel,
        out_shape=jax.ShapeDtypeStruct(other.shape, other.dtype)
    )(diags, offsets, other)

dia_matmul(diags, offsets, other)
I understand that it is not best practice to print inside a JAX-jitted function, but when I print my offsets, which should be static because of static_argnums=(1,), it says:
(Traced<MemRef<None>{int32[]}>with<DynamicJaxprTrace(level=3/0)>, Traced<MemRef<None>{int32[]}>with<DynamicJaxprTrace(level=3/0)>, Traced<MemRef<None>{int32[]}>with<DynamicJaxprTrace(level=3/0)>)
I don't understand why that is the case. I'm new to JAX and Pallas, so I'm not yet fully confident with this whole tracing concept. Also, the last operation of the for loop (the update of out) is not working, so if anyone has an idea about that too :D
Many thanks!
The arguments passed to pallas_call will always be traced, regardless of whether they were static before being passed to pallas_call. This is true any time you pass arguments to a JAX primitive or a transformed function: all inputs will be traced unless they are explicitly marked as static in the function you are calling.
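The same thing can be seen without Pallas. Below is a minimal sketch using only jax.jit (the function names and the seen dict are illustrative, not from the question): n is static in outer, but it becomes a tracer again as soon as it is passed to another jitted function that does not mark it static.

```python
import functools

import jax
import jax.numpy as jnp

seen = {}  # records, at trace time, whether each inner function got a Python int

def scale_traced(x, n):
    seen["traced"] = isinstance(n, int)  # False: n was converted to a tracer
    return x * n

@functools.partial(jax.jit, static_argnums=1)
def scale_static(x, n):
    seen["static"] = isinstance(n, int)  # True: n stays a static Python int
    return x * n

@functools.partial(jax.jit, static_argnums=1)
def outer(x, n):
    # n is static here, but the first call does not mark it static,
    # so inside scale_traced it is traced like any other input.
    a = jax.jit(scale_traced)(x, n)
    b = scale_static(x, n)
    return a + b

print(outer(jnp.ones(3), 2))  # [4. 4. 4.]
print(seen)  # {'traced': False, 'static': True}
```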
pallas_call doesn't currently have any way of marking static arguments, and will trace all arguments passed to the wrapped function. If you want some arguments to be static, you should be able to do this by closing over them in the function that you pass to pallas_call:
@functools.partial(jax.jit, static_argnums=(1, ))
def dia_matmul(diags: Array, offsets: tuple[int, ...], other: Array) -> Array:
    return pl.pallas_call(
        lambda diags_ref, other_ref, o_ref: dia_matmul_kernel(diags_ref, offsets, other_ref, o_ref),
        out_shape=jax.ShapeDtypeStruct(other.shape, other.dtype)
    )(diags, other)
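Closing over offsets makes them plain Python ints inside the kernel, but the question's kernel still computes start with jax.lax.max, which inside a trace most likely returns a traced value rather than an int. Slicing diag[start:end] then needs static bounds, which is probably also why the out.at[...] update fails. A minimal end-to-end sketch, assuming the closure approach above, builtin max/min for the bounds, and interpret=True so it runs on CPU (the inner name kernel and the split random keys are illustrative):

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def dia_matmul_kernel(diags_ref, offsets, other_ref, o_ref):
    # offsets is a plain Python tuple of ints here, so every slice bound below
    # is a static Python int (builtin max/min rather than jax.lax.max).
    diags, other = diags_ref[...], other_ref[...]
    n = other.shape[0]
    out = jnp.zeros((n, n), dtype=other.dtype)
    for i, offset in enumerate(offsets):
        start = max(0, offset)
        end = min(n, n + offset)
        top = max(0, -offset)
        bottom = top + end - start
        # Static slices, so this is an ordinary fixed-shape update.
        out = out.at[top:bottom, :].add(diags[i, start:end, None] * other[start:end, :])
    o_ref[...] = out


@functools.partial(jax.jit, static_argnums=(1,))
def dia_matmul(diags, offsets, other):
    # offsets is static for jit; close over it so pallas_call only receives
    # the array arguments (which it will trace).
    def kernel(diags_ref, other_ref, o_ref):
        dia_matmul_kernel(diags_ref, offsets, other_ref, o_ref)

    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(other.shape, other.dtype),
        interpret=True,  # assumption: lets the sketch run on CPU
    )(diags, other)


key = jax.random.PRNGKey(52)
k_other, k_diags = jax.random.split(key)
other = jax.random.normal(k_other, (10, 10))
diags = jax.random.normal(k_diags, (3, 10))
offsets = (-2, 1, 2)

result = dia_matmul(diags, offsets, other)
print(result.shape)  # (10, 10)
```

An inner def is equivalent to the lambda above. functools.partial(dia_matmul_kernel, offsets=offsets) would not work as-is with this kernel signature, because the positional ref arguments would collide with the offsets parameter.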