Reputation: 431
I was wondering how we could use jax (https://github.com/google/jax) to compute an element-wise mapping of the derivative.
That is to say: we have a vector x in R^n and we want to apply (with the jax framework) a function f to it, where f is a function f : R^n -> R^n.
My question is: how can we easily retrieve the vector of diagonal derivatives (∂f_1/∂x_1, ..., ∂f_n/∂x_n)?
For example:
from jax import random
from jax import jacfwd, jacrev
import jax.numpy as jnp
key = random.PRNGKey(0)
key, W_key, b_key, input_key = random.split(key, 4)
W = random.normal(W_key, (10, 10))
b = random.normal(b_key, (10, ))
input = random.normal(input_key, (10, ))
def f(input):
    return jnp.dot(W, input) + b

J = jacfwd(f, argnums=0)(input)
result = jnp.diagonal(J)

One easy way to do this, shown above, is to take the diagonal of the full Jacobian, but this method is very slow for high-dimensional vectors (> 10000), and I am only interested in the diagonal of the Jacobian.
As a reminder, the Jacobian matrix is J_ij = ∂f_i/∂x_j.
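Since this example f is affine (f(x) = Wx + b), its Jacobian is just W, which gives a quick sanity check for the toy case:
# For this affine f, the Jacobian is W itself, so the diagonal we are
# after is simply the diagonal of W.
assert jnp.allclose(result, jnp.diagonal(W))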
Upvotes: 1
Views: 1022
Reputation: 86320
There's not really a natural way to do this with JAX's transforms: you cannot simply map over the input, because in general each diagonal entry of the Jacobian depends on all of the inputs.
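For example, take a made-up function like the following (my own illustration), where every output mixes all of the inputs:
# Hypothetical illustration: g(x)[i] = sin(sum(x)) * x[i], so the
# diagonal entry ∂g_i/∂x_i = cos(sum(x)) * x[i] + sin(sum(x))
# depends on the whole vector x, not just on x[i].
def g(x):
    return jnp.sin(jnp.sum(x)) * x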
But given your particular function, you could compute the diagonal of the Jacobian directly by rewriting the function like this:
from jax import vmap, grad

# f_single(input[i], i) == f(input)[i], and its derivative with respect
# to val is the i-th diagonal entry of the Jacobian of f.
def f_single(val, i, W=W, b=b, input=input):
    return jnp.dot(W[i], input.at[i].set(val)) + b[i]
idx = jnp.arange(len(input))
# equivalent to f(input)
print(vmap(f_single)(input, idx))
# [-1.5965443 -1.4081277 1.866176 -0.9789318 2.6717818 -1.0995009
# -2.3647223 3.6962256 3.3946664 2.589026 ]
# equivalent to jnp.diagonal(jacrev(f)(input))
print(vmap(grad(f_single))(input, idx))
# [-0.87553114 0.543098 2.265052 0.1403018 -1.4744948 1.4401387
# 0.4466088 0.72063404 -0.9135868 0.34965768]
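As a quick check (a sketch, reusing jacrev from the question's imports), the vmapped result matches the diagonal of the full Jacobian computed the naive way:
from jax import jacrev

# Sanity check against the full-Jacobian approach; only feasible for
# small n, since jacrev materializes the whole n-by-n matrix that the
# vmap(grad(...)) version avoids.
assert jnp.allclose(vmap(grad(f_single))(input, idx),
                    jnp.diagonal(jacrev(f)(input)))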
Upvotes: 2