Igor Rivin

Reputation: 4864

Vectorization guidelines for JAX

Suppose I have a function (for simplicity, the covariance between two series, though the question is more general):

def cov(x, y):
    # Sum of products of deviations from the mean (not divided by n - 1).
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))
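Strictly speaking, cov as written returns the unnormalized sum of products of deviations. Dividing by n - 1 gives the sample covariance that jnp.cov reports by default. A minimal check, assuming two short made-up series x and y:

```python
import jax.numpy as jnp

def cov(x, y):
    # Unnormalized: sum of products of deviations from the mean.
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))

# Hypothetical example series of equal length.
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.array([2.0, 1.0, 4.0, 3.0])

n = x.shape[0]
sample_cov = cov(x, y) / (n - 1)   # rescale to the sample covariance
reference = jnp.cov(x, y)[0, 1]    # off-diagonal entry of jnp.cov's 2x2 result
```

For these inputs, both quantities agree. This matters only if you want actual covariances. The vectorization question itself is unaffected.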

Now I have a "dataframe" D (a 2-dimensional array whose columns are my series), and I want to vectorize cov so that applying the vectorized function produces the covariance matrix. There is an obvious way of doing it:

cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))

but that seems a little clunky. Is there a "canonical" way of doing this?
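The nested vmaps above are a vectorized version of an explicit double loop over column pairs. A small sketch comparing the two, assuming a random D with 50 rows and 3 columns (shape and seed chosen arbitrarily):

```python
import jax
import jax.numpy as jnp

def cov(x, y):
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))

# Inner vmap: keep x fixed, map over the columns of y -> vector of length k.
cov1 = jax.vmap(cov, in_axes=(None, 1))
# Outer vmap: map over the columns of x -> (k, k) matrix.
cov2 = jax.vmap(cov1, in_axes=(1, None))

# Hypothetical data: 50 observations of 3 series.
D = jax.random.normal(jax.random.PRNGKey(0), (50, 3))
C_vmap = cov2(D, D)

# The explicit double loop that the nested vmaps replace.
k = D.shape[1]
C_loop = jnp.array([[cov(D[:, i], D[:, j]) for j in range(k)] for i in range(k)])
```

Entry [i, j] of C_vmap pairs column i of the first argument with column j of the second. This is why the outer vmap is the one mapping over x.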

Upvotes: 1

Views: 459

Answers (1)

jakevdp

Reputation: 86310

If you want to express logic equivalent to nested for loops with vmap, then yes, it requires nested vmaps. I think what you've written is probably as canonical as it gets for an operation like this, although it might be slightly clearer if written using decorators:

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.vmap, in_axes=(1, None))
@partial(jax.vmap, in_axes=(None, 1))
def cov(x, y):
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))
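Once decorated, cov takes two 2D arrays directly. Because the two vmaps map over x and y independently, the arrays can also have different numbers of columns, as long as they share the same number of rows. The sketch below uses two made-up frames X and Y to illustrate this:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.vmap, in_axes=(1, None))   # outer: one output row per column of x
@partial(jax.vmap, in_axes=(None, 1))   # inner: one output entry per column of y
def cov(x, y):
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))

# Hypothetical frames: same 100 rows, different numbers of columns.
key_x, key_y = jax.random.split(jax.random.PRNGKey(1))
X = jax.random.normal(key_x, (100, 2))
Y = jax.random.normal(key_y, (100, 5))

C_xx = cov(X, X)   # shape (2, 2)
C_xy = cov(X, Y)   # shape (2, 5); entry [i, j] pairs X[:, i] with Y[:, j]
```

Decorators are applied bottom-up. The inner vmap is therefore attached first, which matches cov1 and cov2 in the question.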

For this particular function, though, note that you can express the same thing using a single dot product if you wish (here x and y are the 2D arrays themselves, e.g. both equal to D):

result = jnp.dot((x - x.mean(0)).T, (y - y.mean(0)))
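To check that the single dot product matches the nested-vmap version, and to see how both relate to jnp.cov, here is a sketch. It assumes x = y = D with made-up data, and the helper names cov_matrix_vmap and cov_matrix_dot are purely illustrative:

```python
import jax
import jax.numpy as jnp

def cov(x, y):
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))

# Nested-vmap version from the question, written inline.
cov_matrix_vmap = jax.vmap(jax.vmap(cov, in_axes=(None, 1)), in_axes=(1, None))

def cov_matrix_dot(x, y):
    # Center each column, then one matrix product covers all column pairs.
    return jnp.dot((x - x.mean(0)).T, y - y.mean(0))

# Hypothetical data: 200 observations of 6 series.
D = jax.random.normal(jax.random.PRNGKey(2), (200, 6))
n = D.shape[0]

C_vmap = cov_matrix_vmap(D, D)
C_dot = cov_matrix_dot(D, D)
C_ref = jnp.cov(D, rowvar=False) * (n - 1)   # undo jnp.cov's 1 / (n - 1) scaling
```

If all you need is the sample covariance matrix, jnp.cov(D, rowvar=False) computes it directly. The vmap pattern remains useful when the pairwise function has no such closed form.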

Upvotes: 1
