bachr

Reputation: 6006

How to batch a JAX array with vmap

I have a function that operates on batches of an array, defined like this:

```python
import jax
import jax.numpy as jnp

def batched_fn(X):
  @jax.jit
  def apply(Xb):
    Xb_out = ...  # computed from both X (closed over) and Xb
    return Xb_out
  return apply
```

The apply function uses both X and Xb to compute Xb_out, and it can be called one batch at a time like this:

```python
n = X.shape[0]
batches = []
batch_apply = batched_fn(X)
for i in range(0, n, batch_size):
  s = slice(i, min(i + batch_size, n))
  Xb = batch_apply(X[s])
  batches.append(Xb)
X_out = jnp.concatenate(batches, axis=0)
```
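For reference, here is a runnable version of that loop. The question elides the body of apply, so the centering step below (subtracting the mean of the full X from each batch) is a hypothetical stand-in, as are the helper name apply_in_batches and the toy input.

```python
import jax
import jax.numpy as jnp

def batched_fn(X):
    # Hypothetical body: the real computation is elided in the question.
    # It uses the full array X (via its column mean) and the batch Xb.
    mean = X.mean(axis=0)

    @jax.jit
    def apply(Xb):
        return Xb - mean
    return apply

def apply_in_batches(X, batch_size):
    n = X.shape[0]
    batch_apply = batched_fn(X)
    batches = []
    for i in range(0, n, batch_size):
        s = slice(i, min(i + batch_size, n))
        batches.append(batch_apply(X[s]))  # apply receives a whole batch
    return jnp.concatenate(batches, axis=0)

X = jnp.arange(20.0).reshape(10, 2)
X_out = apply_in_batches(X, batch_size=4)
print(X_out.shape)  # (10, 2)
```

One side effect of this pattern: the final partial batch (2 rows here) has a different shape from the others, so jax.jit compiles apply a second time for it.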

I tried to rewrite the above using jax.vmap like this:

```python
func = batched_fn(X)
X_out = jax.vmap(func)(X)
```

This seems to call func with a single row rather than with a batch of rows.
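The observation can be confirmed by recording the shape func receives while vmap traces it. In this small sketch, func, seen_shapes, and the toy input are illustrative only; the recorded shape is the per-row shape (3,), not the full (10, 3).

```python
import jax
import jax.numpy as jnp

seen_shapes = []

def func(Xb):
    # Python side effect: runs while vmap traces func, recording its input shape.
    seen_shapes.append(Xb.shape)
    return Xb * 2.0

X = jnp.ones((10, 3))
X_out = jax.vmap(func)(X)

print(seen_shapes[0])  # (3,): func sees one row
print(X_out.shape)     # (10, 3): results for all rows, stacked
```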

What is the proper way to batch a JAX array?

Upvotes: -1

Views: 1518

Answers (1)

jakevdp

Reputation: 86300

It sounds like this is working as expected. jax.vmap is not a batching transform in the way you're thinking about it; it is a vectorizing transform, and jax.vmap(func)(X) is semantically equivalent to calling func on each row of X (each slice along axis 0) one at a time and stacking the results.
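A minimal sketch of both points, assuming a hypothetical apply that subtracts the column mean of X (the real body is elided in the question). Part 1 checks that vmap matches a per-row loop. Part 2 is one possible way (not the only one) to hand apply genuine fixed-size batches: pad X to a multiple of batch_size, reshape to (num_batches, batch_size, ...), and use jax.lax.map to process the batches sequentially. The helper apply_in_batches is illustrative. Padding is only harmless if apply treats rows independently; if it computes statistics over the batch, the padded rows would affect the last batch.

```python
import jax
import jax.numpy as jnp

def batched_fn(X):
    # Hypothetical body standing in for the elided computation.
    mean = X.mean(axis=0)
    def apply(Xb):
        return Xb - mean
    return apply

X = jnp.arange(20.0).reshape(10, 2)
apply = batched_fn(X)

# 1) vmap: equivalent to calling apply on each row and stacking the results.
by_vmap = jax.vmap(apply)(X)
by_loop = jnp.stack([apply(row) for row in X])

# 2) Real batches: reshape to (num_batches, batch_size, ...) and map over axis 0.
def apply_in_batches(apply, X, batch_size):
    n = X.shape[0]
    pad = (-n) % batch_size  # rows needed to reach a multiple of batch_size
    Xp = jnp.pad(X, [(0, pad)] + [(0, 0)] * (X.ndim - 1))
    Xr = Xp.reshape(-1, batch_size, *X.shape[1:])
    out = jax.lax.map(apply, Xr)  # apply sees shape (batch_size, ...)
    return out.reshape(-1, *out.shape[2:])[:n]  # flatten and drop padded rows

X_out = apply_in_batches(apply, X, batch_size=4)
print(X_out.shape)  # (10, 2)
```

jax.lax.map runs the batches one after another, which keeps peak memory proportional to one batch. Using jax.vmap(apply)(Xr) instead would also pass apply arrays of shape (batch_size, ...), but it evaluates all batches at once.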

Upvotes: 1
