Reputation: 6006
I have a function that works on batches of an array, defined like this:
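```python
import jax
import jax.numpy as jnp

def batched_fn(X):
    @jax.jit
    def apply(Xb):
        Xb_out = ...  # computed from X (captured by the closure) and the batch Xb
        return Xb_out
    return apply
```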
The apply function uses both X and Xb to compute Xb_out, and I call it batch by batch like this:
```python
n = X.shape[0]
batches = []
batch_apply = batched_fn(X)
for i in range(0, n, batch_size):
    s = slice(i, min(i + batch_size, n))
    Xb = batch_apply(X[s])
    batches.append(Xb)
X_out = jnp.concatenate(batches, axis=0)
```
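For reference, here is a self-contained version of the loop above. The body of apply is a hypothetical stand-in (it subtracts the column means of the full X from each batch), chosen only so the sketch runs; the real computation is elided in the question. Note that a ragged last batch has a different shape, so the jitted apply is traced a second time for it.

```python
import jax
import jax.numpy as jnp

def batched_fn(X):
    # Hypothetical computation: uses X (via its column means) and the batch Xb.
    mu = X.mean(axis=0)

    @jax.jit
    def apply(Xb):
        return Xb - mu

    return apply

def run_in_batches(X, batch_size):
    n = X.shape[0]
    batch_apply = batched_fn(X)
    batches = []
    for i in range(0, n, batch_size):
        s = slice(i, min(i + batch_size, n))
        # The last batch may be smaller, which triggers one extra trace of apply.
        batches.append(batch_apply(X[s]))
    return jnp.concatenate(batches, axis=0)

X = jnp.arange(20.0).reshape(10, 2)
X_out = run_in_batches(X, batch_size=4)
```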
I tried to rewrite the above using jax.vmap like this:
```python
func = batched_fn(X)
X_out = jax.vmap(func)(X)
```
This seems to call func with only one row rather than with a batch of rows!
What is the proper way to batch a JAX array?
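A quick way to see what a vmapped function actually receives is to record the argument's shape while it is traced. In this sketch, func is a hypothetical toy function; under jax.vmap it sees one row of shape (3,), never the full (6, 3) array.

```python
import jax
import jax.numpy as jnp

seen_shapes = []

def func(Xb):
    # Python side effects run at trace time, so this records the shape vmap passes in.
    seen_shapes.append(Xb.shape)
    return Xb * 2

X = jnp.ones((6, 3))
out = jax.vmap(func)(X)
print(seen_shapes)  # [(3,)]
print(out.shape)    # (6, 3)
```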
Reputation: 86300
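It sounds like this is working as expected: vmap is not a batching transform in the way you're thinking of it, but rather a vectorizing transform. jax.vmap(func)(X) maps func over the leading axis of X, so func sees a single row, and the result is equivalent to calling func on one row at a time and stacking the outputs (although it executes as a single vectorized computation).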
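A sketch of two ways to use this, assuming a simple hypothetical computation (subtracting the column means mu): center_row and apply_batch are illustrative names, not from the question. Option 1 writes the function for a single row and lets vmap add the batch dimension. Option 2 keeps a batch-level function and reshapes X into (num_batches, batch_size, d), which assumes n is divisible by batch_size (otherwise pad the last batch).

```python
import jax
import jax.numpy as jnp

def center_row(x, mu):
    # Written for a single row x of shape (d,); vmap supplies the batch dimension.
    return x - mu

def apply_batch(Xb, mu):
    # Written for a whole batch Xb of shape (batch_size, d).
    return Xb - mu

X = jnp.arange(24.0).reshape(12, 2)
mu = X.mean(axis=0)

# Option 1: per-row function + vmap; in_axes=(0, None) maps over rows of X
# and passes the same mu to every call.
out_rows = jax.vmap(center_row, in_axes=(0, None))(X, mu)

# Option 2: keep the batch-level function and add an explicit batch axis.
batch_size = 4  # assumption: X.shape[0] is divisible by batch_size
Xbatches = X.reshape(-1, batch_size, X.shape[1])  # (3, 4, 2)

# vmap over the batch axis computes all batches at once ...
out_vmap = jax.vmap(apply_batch, in_axes=(0, None))(Xbatches, mu).reshape(X.shape)

# ... while jax.lax.map processes one batch at a time (sequential, lower peak memory).
out_map = jax.lax.map(lambda Xb: apply_batch(Xb, mu), Xbatches).reshape(X.shape)
```

If memory is the reason for batching in the first place, jax.lax.map is usually the closer match to the original Python loop, since vmap over all batches materializes the whole result in one computation.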