Evan Mata

JAX vmap JIT behind the scenes?

I'm trying to vmap a function. My understanding of vmap is that essentially anywhere I would write a for loop or list comprehension, I should instead consider vmapping. I have a few points of confusion:

  1. Does vmap need fixed sizes for everything through the function(s) being vmapped?
  2. Does vmap try to JIT my function behind the scenes? (I'm wondering because 1 is a behavior I expect from JIT; I didn't expect it from vmap, but I don't really know vmap.)
  3. If vmap is JIT-ing something, how would one use something like static arguments with vmap?
  4. What is the best practice for dealing with extraneous information? (E.g. if some outputs have size a and some have size b, do you just make an array of size max(a, b) and then ignore the extra values?)

The reason I'm asking is that vmap, like JIT, seems to run into all sorts of ConcretizationTypeErrors and (though I'm not 100% sure yet) seems to need constant-sized items for everything. I associate this behavior with any function I'm trying to JIT, but not necessarily with every function I write in JAX.
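As a small illustration of the loop-to-vmap idea in the question, the sketch below uses a made-up function f (not from the question) and computes the same result once with a list comprehension and once with jax.vmap:

```python
import jax
import jax.numpy as jnp

def f(x):
    # toy per-row function (hypothetical): elementwise sin scaled by the row sum
    return jnp.sin(x) * x.sum()

xs = jnp.arange(12.0).reshape(4, 3)

# Python-level loop: calls f once per row, then stacks the results
looped = jnp.stack([f(x) for x in xs])

# vmap: traces f once and maps it over the leading axis of xs
vmapped = jax.vmap(f)(xs)

print(jnp.allclose(looped, vmapped))  # True
```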


Answers (2)

jakevdp

Does vmap need fixed sizes for everything through the function(s) being vmapped?

Yes. vmap, like all JAX transformations, requires any arrays defined in the function to have static shapes.
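A minimal sketch of what this means in practice; both functions are hypothetical examples. A function whose output shape depends only on its input's shape vmaps fine, while one whose output shape depends on an input's value (here jnp.arange(n) with a batched n) fails, because under vmap n is an abstract tracer rather than a concrete number:

```python
import jax
import jax.numpy as jnp

def fixed_shape(x):
    # output shape is determined by the input's shape alone
    return jnp.cumsum(x)

def data_dependent_shape(n):
    # output length depends on the *value* of n, which vmap cannot know
    return jnp.arange(n)

xs = jnp.arange(6.0).reshape(2, 3)
print(jax.vmap(fixed_shape)(xs).shape)  # (2, 3)

raised = False
try:
    jax.vmap(data_dependent_shape)(jnp.array([1, 2, 3]))
except jax.errors.ConcretizationTypeError:
    # the same class of error you would see when jit-compiling this function
    raised = True
print(raised)  # True
```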

Does vmap try to JIT my function behind the scenes? (I'm wondering because 1 is a behavior I expect from JIT; I didn't expect it from vmap, but I don't really know vmap.)

No, vmap does not JIT-compile a function by default, although you can always compose the two if you wish (e.g. jit(vmap(f))).
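A sketch of that composition, assuming a toy function f: vmap on its own executes the batched function eagerly, and wrapping it in jax.jit compiles the whole batched computation. Both give the same result:

```python
import jax
import jax.numpy as jnp

def f(x, w):
    # toy per-example function: dot product of one row with shared weights
    return jnp.dot(w, x)

batched = jax.vmap(f, in_axes=(0, None))  # map over rows of x, share w
compiled = jax.jit(batched)               # explicitly compile the batched function

xs = jnp.ones((4, 3))
w = jnp.arange(3.0)

out_eager = batched(xs, w)   # no whole-function compilation here
out_jit = compiled(xs, w)    # compiled on first call for these shapes
print(out_eager, out_jit)    # [3. 3. 3. 3.] [3. 3. 3. 3.]
```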

If vmap is JIT-ing something, how would one use something like static arguments with vmap?

As mentioned, vmap is unrelated to jit, but the analogue of jit's static_argnums is passing None in in_axes for that argument, which keeps the argument unmapped and therefore static within the transformation.
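For example (take_first is a hypothetical function), an argument marked None in in_axes is not batched, so a plain Python int can still be used as a slice length inside the vmapped function. If you additionally wrap the vmapped function in jit, jit would trace that argument, so it needs to be marked static for jit as well:

```python
import jax
import jax.numpy as jnp

def take_first(x, n):
    # n is used as a slice length, which must be a concrete Python int
    return x[:n]

xs = jnp.arange(12.0).reshape(3, 4)

# in_axes=(0, None): map over rows of xs; n is left unmapped, not a batched tracer
out = jax.vmap(take_first, in_axes=(0, None))(xs, 2)
print(out.shape)  # (3, 2)

# Under jit, n would otherwise be traced, so also mark it static for jit
jitted = jax.jit(jax.vmap(take_first, in_axes=(0, None)), static_argnums=1)
print(jitted(xs, 2).shape)  # (3, 2)
```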

What is the best practice for dealing with extraneous information? (E.g. if some outputs have size a and some have size b, do you just make an array of size max(a, b) and then ignore the extra values?)
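One common way to implement the padding approach the question describes is to pad every output to a fixed maximum size and return a boolean mask alongside it. A minimal sketch, assuming a known upper bound MAX_LEN and a toy function first_k_squares (both hypothetical):

```python
import jax
import jax.numpy as jnp
import numpy as np

MAX_LEN = 5  # hypothetical static upper bound on the output length

def first_k_squares(k):
    # Always returns shape (MAX_LEN,); entries at positions >= k are padding
    idx = jnp.arange(MAX_LEN)
    mask = idx < k
    vals = jnp.where(mask, idx ** 2, 0)
    return vals, mask

ks = jnp.array([2, 5, 3])
vals, mask = jax.vmap(first_k_squares)(ks)  # both have shape (3, MAX_LEN)

# Drop the padding outside of any JAX transformation, on concrete arrays
vals_np, mask_np = np.asarray(vals), np.asarray(mask)
ragged = [v[m].tolist() for v, m in zip(vals_np, mask_np)]
print(ragged)  # [[0, 1], [0, 1, 4, 9, 16], [0, 1, 4]]
```

Inside the transformation everything keeps a fixed shape; the ragged per-item results are only reconstructed at the end, where data-dependent shapes are allowed.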


Evan Mata

A section of my code now looks like:

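import jax
import jax.numpy as jnp

# my_func, parallel_axes, inds_1, inds_2, other_relevant_inputs,
# num_items and batch_size are defined elsewhere in my code.
vmaped_f = jax.vmap(my_func, in_axes=parallel_axes, out_axes=0)
n_batches = -(-num_items // batch_size)  # ceiling division: round up for a partial last batch

# Convert the index lists to arrays once, outside the loop
inds_1_arr = jnp.asarray(inds_1)
inds_2_arr = jnp.asarray(inds_2)

all_vals = []
for i in range(n_batches):
    start = i * batch_size
    top = min(num_items, (i + 1) * batch_size)
    batch_inds_1 = inds_1_arr[start:top]
    batch_inds_2 = inds_2_arr[start:top]
    f_vals = vmaped_f(batch_inds_1, batch_inds_2, other_relevant_inputs)
    all_vals.extend(f_vals.tolist())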

The vmapped function takes in all of my data plus the indices of the data to use. The index batches are constant-sized except possibly for the last one, so if you wanted to JIT it, it would only need to compile twice.
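A possible refinement of the batching loop, sketched under assumptions: my_func, data, and the index lists below are hypothetical stand-ins, and the data argument is assumed to be shared across the batch (in_axes=None). Since jit compiles once per distinct input shape, padding the final batch up to batch_size gives every call the same shapes, so it compiles once instead of twice; the padded results are dropped at the end:

```python
import jax
import jax.numpy as jnp

def my_func(i, j, data):
    # hypothetical stand-in for the real per-item function
    return data[i] * data[j]

# i and j are mapped; data is shared by every item in the batch
batched_f = jax.jit(jax.vmap(my_func, in_axes=(0, 0, None)))

def run_in_batches(inds_1, inds_2, data, batch_size):
    inds_1 = jnp.asarray(inds_1)
    inds_2 = jnp.asarray(inds_2)
    num_items = inds_1.shape[0]
    n_batches = -(-num_items // batch_size)  # ceiling division
    pad = n_batches * batch_size - num_items
    # Pad with index 0 (any valid index works) so all batches share one shape
    inds_1 = jnp.pad(inds_1, (0, pad))
    inds_2 = jnp.pad(inds_2, (0, pad))
    outs = []
    for b in range(n_batches):
        sl = slice(b * batch_size, (b + 1) * batch_size)
        outs.append(batched_f(inds_1[sl], inds_2[sl], data))
    return jnp.concatenate(outs)[:num_items]  # discard results for padded slots

result = run_in_batches(list(range(7)), list(range(6, -1, -1)),
                        jnp.arange(10.0), batch_size=3)
print(result)  # [0. 5. 8. 9. 8. 5. 0.]
```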

