Reputation: 612
I'm wondering if there is a good way to limit the memory usage of JAX's vmap. Equivalently, is there a way to vmap in batches at a time, if that makes sense?
In my specific use case, I have a set of images and I'd like to calculate the affinity between each pair of images, so roughly on the order of (num_imgs)^2 * (img shape) bytes of memory used all at once if I'm understanding vmap correctly (which gets huge, since in my real example I have 10,000 100x100 images).
A basic example is:
from itertools import combinations

import jax
import jax.numpy as jnp

def affinity_matrix_ex(n_arrays=10, img_size=5, key=jax.random.PRNGKey(0), gamma=jnp.array([0.5])):
    arr_of_imgs = jax.random.normal(key, (n_arrays, img_size, img_size))
    arr_of_indices = jnp.arange(n_arrays)
    inds_1, inds_2 = zip(*combinations(arr_of_indices, 2))  # all unique index pairs
    v_cPA = jax.vmap(calcPairAffinity2, (0, 0, None, None), 0)  # map over the index pairs only
    affinities = v_cPA(jnp.array(inds_1), jnp.array(inds_2), arr_of_imgs, gamma)
    print()
    print(jax.make_jaxpr(v_cPA)(jnp.array(inds_1), jnp.array(inds_2), arr_of_imgs, gamma))
    affinities = affinities.reshape(-1)
    arr = jnp.zeros((n_arrays, n_arrays), dtype=jnp.float16)
    arr = arr.at[jnp.triu_indices(arr.shape[0], k=1)].set(affinities)
    arr = arr + arr.T
    arr = arr + jnp.identity(n_arrays, dtype=jnp.float16)
    return arr

def calcPairAffinity2(ind1, ind2, imgs, gamma):
    # Returns a jnp array of 1 float; jnp.sum adds all elements together
    image1, image2 = imgs[ind1], imgs[ind2]
    diff = jnp.sum(jnp.abs(image1 - image2))
    normed_diff = diff / image1.size
    val = jnp.exp(-gamma * normed_diff)
    val = val.astype(jnp.float16)
    return val
I suppose I could just feed vmap X pairs at a time and loop through n_chunks = n_arrays/X, appending each group's results to a list, but that doesn't seem ideal. My understanding is that vmap does not like generators, so I'm not sure whether that would be an alternative way around the issue.
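To make that concrete, inside affinity_matrix_ex the chunked loop I'm imagining would look roughly like this (untested, and the chunk size is an arbitrary placeholder):
# Hypothetical chunked loop replacing the single v_cPA call above; untested.
chunk_size = 1000  # arbitrary; would need tuning to fit memory
results = []
for start in range(0, len(inds_1), chunk_size):
    chunk_1 = jnp.array(inds_1[start:start + chunk_size])
    chunk_2 = jnp.array(inds_2[start:start + chunk_size])
    # vmap only sees one chunk at a time, so intermediates scale with chunk_size
    results.append(v_cPA(chunk_1, chunk_2, arr_of_imgs, gamma))
affinities = jnp.concatenate(results).reshape(-1)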
Upvotes: 4
Views: 580
Reputation: 86443
Edit, Aug 13 2024
As of JAX version 0.4.31, what you're asking for is possible using the batch_size argument of lax.map. For an iterable of size N, this will perform a scan with N // batch_size steps, and within each step will vmap the function over the batch. lax.map has less flexible semantics than jax.vmap, but for the simplest cases they look relatively similar. Here's an example using your calcPairAffinity function:
import jax
import jax.numpy as jnp

def calcPairAffinity(ind1, ind2, imgs, gamma=0.5):
    image1, image2 = imgs[ind1], imgs[ind2]
    diff = jnp.sum(jnp.abs(image1 - image2))
    normed_diff = diff / image1.size
    val = jnp.exp(-gamma * normed_diff)
    val = val.astype(jnp.float16)
    return val

imgs = jax.random.normal(jax.random.key(0), (100, 5, 5))
inds = jnp.arange(imgs.shape[0])
inds1, inds2 = map(jnp.ravel, jnp.meshgrid(inds, inds))  # every (i, j) pair

def f(inds):
    return calcPairAffinity(*inds, imgs, 0.5)

result_vmap = jax.vmap(f)((inds1, inds2))  # all 10,000 pairs at once
result_batched = jax.lax.map(f, (inds1, inds2), batch_size=1000)  # 10 vmapped batches of 1000
assert jnp.allclose(result_vmap, result_batched)
Original answer
This is a frequent request, but unfortunately there's not yet (as of JAX version 0.4.20) any built-in utility to do chunked/batched vmap (xmap does have some functionality along these lines, but it is experimental/incomplete and I wouldn't recommend relying on it). Adding chunking to vmap is tracked in https://github.com/google/jax/issues/11319, and there's some code there that does a limited version of what you have in mind. Hopefully something like what you describe will be possible with JAX's built-in vmap soon. In the meantime, you might think about applying vmap to chunks manually in the way you describe in your question.
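For example, here is a rough sketch of one way to do that manual chunking: pad the flat inputs to a multiple of the chunk size, reshape them into chunks, and use lax.map (a scan) over the chunks with vmap applied within each one. The chunked_vmap helper, the zero-padding scheme, and the chunk size are illustrative choices here, not a built-in API:
import jax
import jax.numpy as jnp

def chunked_vmap(fun, xs, chunk_size):
    # Manual chunked vmap: scan over chunks with lax.map, vmap within each
    # chunk, so only one chunk's worth of intermediates is live at a time.
    # Assumes xs is a tuple of equal-length 1D arrays and fun returns a scalar.
    n = xs[0].shape[0]
    n_pad = (-n) % chunk_size  # pad so the length divides evenly into chunks
    padded = tuple(jnp.concatenate([x, jnp.zeros(n_pad, x.dtype)]) for x in xs)
    chunks = tuple(x.reshape(-1, chunk_size) for x in padded)
    out = jax.lax.map(jax.vmap(fun), chunks)  # one vmapped call per chunk
    return out.reshape(-1)[:n]  # drop the padded entries

# e.g. with f, inds1, inds2 from the example above:
# result = chunked_vmap(f, (inds1, inds2), chunk_size=1000)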
Upvotes: 3