Adrien Forbu

Reputation: 431

Jacobian diagonal computation in JAX

I was wondering how we could use jax (https://github.com/google/jax) to compute an element-wise mapping of the derivative.

That is to say: we have a vector x in R^n and we want to apply (with the jax framework) a function to it; call it f, a function from R^n to R^n.

My question is: how can we easily retrieve the vector (∂f_1/∂x_1, ..., ∂f_n/∂x_n), i.e. the diagonal of the Jacobian?

For example:

from jax import random
from jax import jacfwd, jacrev
import jax.numpy as jnp

key = random.PRNGKey(0)

key, W_key, b_key, input_key = random.split(key, 4)
W = random.normal(W_key, (10, 10))
b = random.normal(b_key, (10, ))
input = random.normal(input_key, (10, ))

One easy way to do this would be to take the diagonal of the Jacobian, but this method is very slow for high-dimensional vectors (> 10000). I am only interested in the diagonal of the Jacobian...

def f(input):
  return jnp.dot(W, input) + b

# Compute the full 10x10 Jacobian, then keep only its diagonal
J = jacfwd(f, argnums=0)(input)

result = jnp.diagonal(J)

As a reminder, the Jacobian matrix is the n × n matrix with entries J_ij = ∂f_i / ∂x_j.

Upvotes: 1

Views: 1022

Answers (1)

jakevdp

Reputation: 86320

There's not really a natural way to do this with JAX's transforms: you cannot simply map the input, because in general each diagonal entry of the jacobian depends on all inputs.
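To make that concrete, here is a sketch (not this answer's approach) of the best generic fallback: one JVP per coordinate via `vmap`. It avoids materializing the full n×n Jacobian, but still costs one forward pass per diagonal entry, precisely because each entry can depend on every input:

```python
import jax.numpy as jnp
from jax import jvp, vmap, random

key = random.PRNGKey(0)
key, W_key, b_key, input_key = random.split(key, 4)
W = random.normal(W_key, (10, 10))
b = random.normal(b_key, (10,))
x = random.normal(input_key, (10,))

def f(x):
    return jnp.dot(W, x) + b

def jac_diag(f, x):
    # jvp(f, (x,), (e_i,)) returns (f(x), J @ e_i); entry i of the
    # tangent J @ e_i is the diagonal element J[i, i].
    basis = jnp.eye(len(x), dtype=x.dtype)
    return vmap(lambda e, i: jvp(f, (x,), (e,))[1][i])(basis, jnp.arange(len(x)))

diag = jac_diag(f, x)
```

For this linear `f` the Jacobian is just `W`, so `diag` equals `jnp.diagonal(W)`; the point is only that the generic route cannot do better than one product with each basis vector.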

But given your particular function, you could compute the diagonal of the jacobian directly by rewriting the function like this:

from jax import vmap, grad

def f_single(val, i, W=W, b=b, input=input):
  return jnp.dot(W[i], input.at[i].set(val)) + b[i]

idx = jnp.arange(len(input))

# equivalent to f(input)
print(vmap(f_single)(input, idx))
# [-1.5965443 -1.4081277  1.866176  -0.9789318  2.6717818 -1.0995009
#  -2.3647223  3.6962256  3.3946664  2.589026 ]

# equivalent to jnp.diagonal(jacrev(f)(input))
print(vmap(grad(f_single))(input, idx))
# [-0.87553114  0.543098    2.265052    0.1403018  -1.4744948   1.4401387
#   0.4466088   0.72063404 -0.9135868   0.34965768]
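A quick sanity check of the rewrite (a sketch, reusing the same `W`, `b` and `input` as in the question) is to compare the vmapped scalar gradient against the naive diagonal-of-the-Jacobian result:

```python
import jax.numpy as jnp
from jax import random, vmap, grad, jacrev

key = random.PRNGKey(0)
key, W_key, b_key, input_key = random.split(key, 4)
W = random.normal(W_key, (10, 10))
b = random.normal(b_key, (10,))
input = random.normal(input_key, (10,))

def f(input):
    return jnp.dot(W, input) + b

def f_single(val, i, W=W, b=b, input=input):
    return jnp.dot(W[i], input.at[i].set(val)) + b[i]

idx = jnp.arange(len(input))

# The vmapped scalar gradient should match the diagonal of the full Jacobian.
naive = jnp.diagonal(jacrev(f)(input))
fast = vmap(grad(f_single))(input, idx)
print(jnp.allclose(naive, fast))  # True
```

The difference is cost: `jacrev(f)` builds the whole n×n matrix before discarding everything off the diagonal, while the rewritten version only ever differentiates n scalar functions.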

Upvotes: 2
