Reputation: 612
I'm trying to vmap a function. My understanding is that, roughly, anywhere I would write a for loop or list comprehension, I should instead consider vmapping. I have a few points of confusion:
The reason I'm asking is that vmap, like jit, seems to run into all sorts of ConcretizationTypeErrors and seems (I'm not 100% sure yet) to need constant-sized arrays for everything. I associate this behavior with functions I'm trying to jit, but not with every function I write in JAX.
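As a minimal illustration of the loop-to-vmap idea, assuming a hypothetical per-example function `row_norm`, the list comprehension and the vmapped call below compute the same result. vmap adds a batch dimension to the computation instead of iterating over rows in Python.

```python
import jax
import jax.numpy as jnp

def row_norm(x):
    # Hypothetical per-example function: Euclidean norm of one row
    return jnp.sqrt(jnp.sum(x ** 2))

xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 rows

# List comprehension: calls row_norm once per row, in Python
loop_result = jnp.stack([row_norm(x) for x in xs])

# vmap: a single call that maps row_norm over axis 0 of xs
vmap_result = jax.vmap(row_norm)(xs)

print(jnp.allclose(loop_result, vmap_result))
```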
Upvotes: 2
Views: 1290
Reputation: 86443
Does vmap need fixed sizes for everything through the function(s) being vmapped?
Yes. vmap, like all JAX transformations, requires any arrays defined in the function to have static shapes.
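A small sketch of what that means in practice, using two hypothetical functions: `positives_only` returns an array whose length depends on the values in `x`, so vmap cannot give it a static shape and raises an error; `positives_masked` keeps the input shape and zeroes out the unwanted entries instead, which vmaps fine.

```python
import jax
import jax.numpy as jnp

def positives_only(x):
    # Output length depends on the *values* of x, so it has no static shape
    return x[x > 0]

def positives_masked(x):
    # Fixed-shape alternative: keep the input shape, zero out non-positive entries
    return jnp.where(x > 0, x, 0.0)

batch = jnp.array([[1.0, -2.0, 3.0],
                   [-1.0, -5.0, 6.0]])

try:
    jax.vmap(positives_only)(batch)
except Exception as err:  # JAX rejects the data-dependent output shape
    print("vmap failed:", type(err).__name__)

print(jax.vmap(positives_masked)(batch))
```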
Does vmap try to JIT my function behind the scenes? (I'm wondering because the behavior in the first question is something I expect from JIT; I didn't expect it from vmap, but I don't really know vmap.)
No, vmap does not jit-compile a function by default, although you can always compose the two if you wish (e.g. jit(vmap(f))).
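To see the difference, here is a small sketch with a hypothetical function `f` that counts how often its Python body runs. Under plain vmap the body is traced again on every call and the batched operations are dispatched one by one; jit(vmap(f)) traces and compiles once per input shape and then reuses the compiled result.

```python
import jax
import jax.numpy as jnp

trace_count = {"vmap": 0, "jit_vmap": 0}

def make_f(key):
    def f(x):
        # This Python line runs only while JAX traces f, not when compiled code runs
        trace_count[key] += 1
        return jnp.sin(x) * 2.0
    return f

xs = jnp.linspace(0.0, 1.0, 5)

vmapped = jax.vmap(make_f("vmap"))
jitted_vmapped = jax.jit(jax.vmap(make_f("jit_vmap")))

for _ in range(3):
    a = vmapped(xs)          # re-traces f on every call: no compilation cache
    b = jitted_vmapped(xs)   # traced and compiled once, then reused

print(trace_count)  # expected: {'vmap': 3, 'jit_vmap': 1}
```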
If vmap is jit-ing something, how would one use something like static arguments with vmap?
As mentioned, vmap is unrelated to jit, but the analog of jit's static_argnums is passing None in in_axes, which keeps the argument unmapped and therefore static within the transformation.
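A minimal sketch of that pattern, with a hypothetical function `first_n_scaled`: `xs` is mapped over its leading axis, while the Python int `n` and the array `scale` are passed with in_axes=None, so every row sees the same values. In this example the unmapped `n` reaches the function as a plain Python int, which is why it can be used to set a shape (here, a slice length).

```python
import jax
import jax.numpy as jnp

def first_n_scaled(x, n, scale):
    # n sets a slice length, so it must be a static (non-traced) value
    return x[:n] * scale

xs = jnp.arange(12.0).reshape(3, 4)   # batch of 3 rows, mapped over axis 0
scale = jnp.array([1.0, 10.0])        # shared by every row, not mapped

# in_axes: map over axis 0 of xs; leave n and scale unmapped (None)
out = jax.vmap(first_n_scaled, in_axes=(0, None, None))(xs, 2, scale)
print(out.shape)
```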
What is the best practice for dealing with extraneous information? (E.g., if some outputs have size a and some have size b, do you just make an array of size max(a, b) and then ignore the extra values?)
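One common way to handle this, sketched here as an assumption rather than as a recommendation from the answer above, is the padding idea from the question itself: give every output the maximum size and return a boolean mask marking which entries are real. The function `first_k`, `MAX_LEN`, and the data are hypothetical.

```python
import jax
import jax.numpy as jnp

MAX_LEN = 4  # hypothetical max(a, b): every output is padded to this length

def first_k(x, k):
    # Keep the first k entries of x (length MAX_LEN); the rest become padding.
    # The output shape is always (MAX_LEN,), so vmap sees a static shape.
    mask = jnp.arange(MAX_LEN) < k        # True where the entry is real data
    values = jnp.where(mask, x, 0.0)      # padding entries are set to 0.0
    return values, mask

xs = jnp.arange(8.0).reshape(2, MAX_LEN)  # two examples
ks = jnp.array([2, 3])                    # per-example "true" lengths

values, mask = jax.vmap(first_k)(xs, ks)

# Reductions inside JAX use the mask so padding never leaks into results
row_sums = jnp.sum(jnp.where(mask, values, 0.0), axis=1)

# Outside any transformation, boolean indexing can drop the padding
real = [v[m].tolist() for v, m in zip(values, mask)]
print(real)
```

The cost of this design is some wasted computation on padded entries; the benefit is that every call has the same shape, so nothing has to be re-traced.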
Upvotes: 3
Reputation: 612
A section of my code now looks like:
import jax
import jax.numpy as jnp

vmaped_f = jax.vmap(my_func, in_axes=parallel_axes, out_axes=0)
n_batches = (num_items + batch_size - 1) // batch_size  # ceiling division: round up
inds_1, inds_2 = jnp.asarray(inds_1), jnp.asarray(inds_2)  # convert once, outside the loop
all_vals = []
for i in range(n_batches):
    # Items [i*batch_size, top) form this batch; the last batch may be shorter
    top = min(num_items, (i + 1) * batch_size)
    batch_inds = jnp.arange(i * batch_size, top)
    batch_inds_1, batch_inds_2 = inds_1[batch_inds], inds_2[batch_inds]
    f_vals = vmaped_f(batch_inds_1, batch_inds_2, other_relevant_inputs)
    all_vals.extend(f_vals.tolist())
The vmapped function takes in all of my data plus the indices of the items to use. The index arrays have a constant size except possibly for the last batch, so if you wanted to jit it, it would only need to be compiled twice.
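A self-contained sketch of the same loop, with a hypothetical `my_func`, toy `data`, and index arrays standing in for the question's inputs, and with the vmapped function wrapped in jax.jit. It records each trace to check the point above: with 10 items and a batch size of 4, the batches have sizes 4, 4 and 2, so the function should be traced (and compiled) only twice.

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry per trace of my_func

def my_func(i1, i2, data):
    # Hypothetical stand-in for the question's my_func: dot product of two rows
    trace_log.append(jnp.shape(i1))  # Python side effect: runs only while tracing
    return jnp.dot(data[i1], data[i2])

data = jnp.arange(30.0).reshape(10, 3)
inds_1 = jnp.arange(10)
inds_2 = jnp.arange(10)[::-1]
num_items, batch_size = 10, 4

# Map over the two index arrays; data is shared across the batch (None)
batched_f = jax.jit(jax.vmap(my_func, in_axes=(0, 0, None)))

all_vals = []
for start in range(0, num_items, batch_size):
    stop = min(num_items, start + batch_size)
    f_vals = batched_f(inds_1[start:stop], inds_2[start:stop], data)
    all_vals.extend(f_vals.tolist())

print(len(all_vals), len(trace_log))  # batch sizes 4, 4, 2 -> two distinct shapes
```

If even two compilations are too many, one option (not something the code above does) is to pad the last batch up to batch_size and discard the extra outputs, so every call has the same shape.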
Upvotes: 0