Reputation: 806
I'm using `jax.jit` on a computation in which a bunch of coefficients for an FFT are calculated. The general flow is:
```python
from functools import partial

@partial(jax.jit,
         static_argnums=(1,),
         in_shardings=(x_sharding,),
         out_shardings=x_sharding)
def my_function(x, n_bins):
    coefficients = calc_fft_coefficients(n_bins)
    return crunch_numbers(x, coefficients)
```
I was wondering a few things:
1. With `jax.jit`, will the arrays computed in the function be computed once (as the computation and shape never change), or will they be recomputed every time the function is called? The arrays do not depend on any non-static inputs to the function.
2. If the `jax.jit`-ed function does recompute these arrays, I can calculate them outside of the function. But then I'm not sure how to handle this with sharding. For example, if my `x_sharding` is along the batch axis, then if I understand correctly, the function will expect anything with that sharding to be a multiple of the batch size. So would I use a different sharding for my coefficients (which are independent of batch size)? If so, what would an appropriate sharding be? Would I even need a sharding, or is it possible to pass non-sharded JAX arrays to a sharded function?

Thanks for any tips!
Upvotes: 1
Views: 369
Reputation: 86443
When you wrap a function in `jit`, it's like creating a self-contained kernel that will be executed on the device. There's no mechanism for detecting and caching reused computations across different kernel calls, so the answer to your question (1) is no.
That said, if you wrap multiple calls to your function in another function that is itself wrapped in `jit`, the compiler will in general do this sort of de-duplication: if certain values are computed multiple times in your Python implementation, they may be de-duplicated within that outer JIT. If you want to see which operations are fused or de-duplicated by the compiler, one way to get insight into this is to use the Ahead-of-time Compilation tools to print the compiled HLO.
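As a minimal sketch of inspecting the compiled HLO with the AOT tools (the function here is an illustrative stand-in, not your real `calc_fft_coefficients`):

```python
import jax
import jax.numpy as jnp

def my_function(x):
    # Stand-in for calc_fft_coefficients: the same values on every call.
    coefficients = jnp.cos(jnp.arange(8) * 2.0 * jnp.pi / 8)
    return x * coefficients

# Lower and compile ahead of time, then print the optimized HLO to see
# what the compiler constant-folded or de-duplicated.
compiled = jax.jit(my_function).lower(jnp.ones(8)).compile()
print(compiled.as_text())
```

`lowered.as_text()` would show the pre-optimization StableHLO instead, which can be useful for comparison.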
Regarding your question (2): I don't think there's enough information here to answer for sure, but I suspect your best bet would be to shard the pre-computed values appropriately using `in_shardings` when you pass them to the function.
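For example, here is a minimal single-device sketch of that idea (the mesh, axis names, and shapes are assumptions for illustration): the batch-dependent input gets the batch sharding, while the pre-computed coefficients get a fully-replicated sharding via an empty `PartitionSpec`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical one-device mesh for illustration; real code would span devices.
mesh = Mesh(jax.devices()[:1], axis_names=("batch",))
x_sharding = NamedSharding(mesh, P("batch"))  # shard x along the batch axis
coef_sharding = NamedSharding(mesh, P())      # replicate the coefficients

# Compute the coefficients once, outside the jitted function.
coefficients = jnp.cos(jnp.arange(8) * 2.0 * jnp.pi / 8)

@partial(jax.jit,
         in_shardings=(x_sharding, coef_sharding),
         out_shardings=x_sharding)
def crunch_numbers(x, coefficients):
    return x * coefficients

x = jax.device_put(jnp.ones((4, 8)), x_sharding)
result = crunch_numbers(x, jax.device_put(coefficients, coef_sharding))
```

Because the coefficients are replicated rather than batch-sharded, their size does not need to be a multiple of the batch-axis device count.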
Upvotes: 1