Milad

Reputation: 5510

Efficient way to compute Jacobian x Jacobian.T

Assume J is the Jacobian of some function f with respect to some parameters. Are there efficient ways (in PyTorch, or perhaps JAX) to write a function that takes two inputs x1 and x2 and computes J(x1) @ J(x2).T without instantiating the full J matrices in memory?

I have come across something like jvp(f, input, v=vjp(f, input)), but I don't quite understand it and am not sure it's what I want.

Upvotes: 4

Views: 1982

Answers (1)

jakevdp

Reputation: 86320

In JAX, you can compute a full jacobian matrix using jax.jacfwd or jax.jacrev, or you can compute a jacobian operator and its transpose using jax.jvp and jax.vjp.

So, for example, say you had a function Rᴺ → Rᴹ that looks something like this:

import jax.numpy as jnp
import numpy as np

np.random.seed(1701)
N, M = 10000, 5
f_mat = np.array(np.random.rand(M, N))
def f(x):
  return jnp.sqrt(f_mat @ x / N)

Given two vectors x1 and x2, you can evaluate the Jacobian matrix at each using jax.jacfwd:

import jax
x1 = np.array(np.random.rand(N))
x2 = np.array(np.random.rand(N))
J1 = jax.jacfwd(f)(x1)
J2 = jax.jacfwd(f)(x2)
print(J1 @ J2.T)
# [[3.3123782e-05 2.5001222e-05 2.4946943e-05 2.5180108e-05 2.4940484e-05]
#  [2.5084497e-05 3.3233835e-05 2.4956826e-05 2.5108084e-05 2.5048916e-05]
#  [2.4969209e-05 2.4896170e-05 3.3232871e-05 2.5006309e-05 2.4947023e-05]
#  [2.5102483e-05 2.4947576e-05 2.4906987e-05 3.3327218e-05 2.4958186e-05]
#  [2.4981882e-05 2.5007204e-05 2.4966144e-05 2.5076926e-05 3.3595043e-05]]
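
As an aside (an addition for illustration, not part of the original answer), jax.jacrev, mentioned above, builds the same matrix via reverse-mode autodiff; with only M = 5 outputs against N = 10,000 inputs, reverse mode is generally the cheaper way to materialize the full Jacobian:

# jacrev works in reverse mode: one backward pass per output row (M passes)
# rather than one forward pass per input column (N passes).
J1_rev = jax.jacrev(f)(x1)
np.allclose(J1, J1_rev)
# True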

But, as you note, along the way to computing this 5x5 result, we instantiate two 5x10,000 matrices. How might we get around this?

The answer is in jax.jvp and jax.vjp. These have somewhat unintuitive call signatures for the purposes of your question, as they are designed primarily for use in forward-mode and reverse-mode automatic differentiation. But broadly, you can think of them as a way to compute J @ v and J.T @ v for a vector v, without having to actually compute J explicitly.
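
To make those signatures concrete (a small sketch added here; the variable names are mine), jax.jvp takes primals and tangents as tuples and returns (f(x), J @ v), while jax.vjp returns f(x) together with a function mapping a cotangent u to a tuple containing J.T @ u, which is why the lambdas below index into the results with [1] and [0]:

v = np.random.rand(N)             # a tangent vector in R^N
y1, Jv = jax.jvp(f, (x1,), (v,))  # y1 = f(x1), Jv = J1 @ v

y2, f_vjp = jax.vjp(f, x2)        # y2 = f(x2); f_vjp maps u -> (J2.T @ u,)
u = np.random.rand(M)             # a cotangent vector in R^M
JTu, = f_vjp(u)                   # JTu = J2.T @ u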

For example, you can use jax.jvp to compute the effect of J1 operating on a vector, without actually computing J1:

J1_op = lambda v: jax.jvp(f, (x1,), (v,))[1]

vN = np.random.rand(N)
np.allclose(J1 @ vN, J1_op(vN))
# True

Similarly, you can use jax.vjp to compute the effect of J2.T operating on a vector, without actually computing J2:

J2T_op = lambda v: jax.vjp(f, x2)[1](v)[0]

vM = np.random.rand(M)
np.allclose(J2.T @ vM, J2T_op(vM))
# True

Putting these together and operating on an identity matrix gives you the full jacobian matrix product that you're after:

def direct(f, x1, x2):
  J1 = jax.jacfwd(f)(x1)
  J2 = jax.jacfwd(f)(x2)
  return J1 @ J2.T

def indirect(f, x1, x2, M):
  J1J2T_op = lambda v: jax.jvp(f, (x1,), jax.vjp(f, x2)[1](v))[1]
  return jax.vmap(J1J2T_op)(jnp.eye(M)).T

np.allclose(direct(f, x1, x2), indirect(f, x1, x2, M))
# True

Along with the memory savings, this indirect method is also a fair bit faster than the direct method, depending on the sizes of the jacobians involved:

%time direct(f, x1, x2)
# CPU times: user 1.43 s, sys: 14.9 ms, total: 1.44 s
# Wall time: 886 ms
%time indirect(f, x1, x2, M)
# CPU times: user 311 ms, sys: 0 ns, total: 311 ms
# Wall time: 158 ms
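
As a further step (an addition, not part of the original answer), jit-compiling the matrix-free version typically trims the remaining Python tracing and dispatch overhead; a minimal sketch, timing a call after the first (compiling) one:

# Sketch: f and M are marked static because f is a Python callable and M
# fixes the shape of jnp.eye, so neither can be a traced argument.
indirect_jit = jax.jit(indirect, static_argnums=(0, 3))
indirect_jit(f, x1, x2, M)  # first call triggers compilation
%timeit indirect_jit(f, x1, x2, M).block_until_ready()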

Upvotes: 6
