Consider the following function (stolen from scipy), which produces a random orthogonal matrix of dimension `dim`:
```python
import numpy as np

def nrvs(dim=3):
    random_state = np.random
    H = np.eye(dim)
    D = np.ones((dim,))
    for n in range(1, dim):
        x = random_state.normal(size=(dim-n+1,))
        D[n-1] = np.sign(x[0])
        x[0] -= D[n-1]*np.sqrt((x*x).sum())
        # Householder transformation
        Hx = (np.eye(dim-n+1) - 2.*np.outer(x, x)/(x*x).sum())
        mat = np.eye(dim)
        mat[n-1:, n-1:] = Hx
        H = np.dot(H, mat)
    # Fix the last sign such that the determinant is 1
    D[-1] = (-1)**(1-(dim % 2))*D.prod()
    # Equivalent to np.dot(np.diag(D), H) but faster, apparently
    H = (D*H.T).T
    return H
```
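Two facts that the comments in `nrvs` rely on can be checked numerically. Each loop iteration embeds a Householder reflection `I - 2 x x^T / (x^T x)`, which is symmetric, orthogonal and has determinant -1; and the final `(D*H.T).T` scales row `i` of `H` by `D[i]`, which is the same as `np.diag(D) @ H`. A small NumPy sketch, with an arbitrary vector and random matrix chosen purely for illustration (the helper name `householder` is hypothetical):

```python
import numpy as np

def householder(x):
    # Hypothetical helper: reflection across the hyperplane orthogonal to x,
    # i.e. I - 2 x x^T / (x^T x), the same formula used for Hx in nrvs
    x = np.asarray(x, dtype=float)
    return np.eye(x.size) - 2.0 * np.outer(x, x) / (x @ x)

Hx = householder([1.0, 2.0, 2.0])
is_symmetric = np.allclose(Hx, Hx.T)
is_orthogonal = np.allclose(Hx @ Hx.T, np.eye(3))
det_Hx = np.linalg.det(Hx)  # a reflection has determinant -1

# Row scaling: (D*H.T).T multiplies row i of H by D[i], like diag(D) @ H
rng = np.random.default_rng(0)
H = rng.normal(size=(3, 3))
D = np.array([1.0, -1.0, 1.0])
same_as_diag = np.allclose((D * H.T).T, np.diag(D) @ H)
same_as_broadcast = np.allclose((D * H.T).T, D[:, None] * H)
print(is_symmetric, is_orthogonal, round(det_Hx, 6), same_as_diag, same_as_broadcast)
```

Since `nrvs` multiplies `dim-1` such reflections, their product has determinant (-1)^(dim-1). The last entry of `D` is then chosen so that the product of all entries of `D` is also (-1)^(dim-1), which makes the determinant of `diag(D) @ H` equal to +1 (ignoring the probability-zero case where some `x[0]` is exactly 0).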
Fool that I was, I thought that one could port this to JAX by just replacing `np` with `jnp`. This, however, seems to be far from true. The first problem is that JAX has a completely different view of random number generation: you pass a key in, and then you "split" that key. Since this function would be used to produce a random SEQUENCE of matrices, it seems that the only reasonable approach is to make this thing into a closure (as below).
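For readers new to this model: JAX's random functions are pure, so the same key always produces the same numbers, and `jax.random.split` is how you derive fresh keys. Drawing a sequence therefore means either threading a key through a loop or splitting into all the keys you need up front. A minimal sketch, drawing plain normal vectors as a stand-in for the matrices:

```python
import jax
import jax.numpy as jnp

# Pure functions: reusing a key reproduces exactly the same sample
demo_key = jax.random.PRNGKey(0)
a = jax.random.normal(demo_key, shape=(3,))
b = jax.random.normal(demo_key, shape=(3,))

# Sequential pattern: keep one key for later, consume the other
key = jax.random.PRNGKey(1)
draws = []
for _ in range(4):
    key, subkey = jax.random.split(key)
    draws.append(jax.random.normal(subkey, shape=(3,)))

# Batch pattern: split into n keys at once and map over them
keys = jax.random.split(jax.random.PRNGKey(2), 4)
batch = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(keys)
```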
A different and even more annoying problem is that something like `D[n-1] = np.sign(x[0])` does NOT work in JAX because JAX arrays are immutable, so the documentation suggests `D.at[n-1].set(jnp.sign(x[0]))`. However, that does not work either on its own, because it does not modify `D`; it returns a new array equal to the desired modified `D`. So, the correct syntax appears to be `D = D.at[n-1].set(jnp.sign(x[0]))`.
Is this really so? It looks really gross. The final result was this:
```python
import jax
import jax.numpy as jnp

def rvs(dim=3):
    key = jax.random.PRNGKey(42)
    def rvs_aux():
        nonlocal key
        H = jnp.eye(dim)
        D = jnp.ones((dim,))
        for n in range(1, dim):
            key, subkey = jax.random.split(key)
            x = jax.random.normal(subkey, shape=(dim-n+1,))
            D = D.at[n-1].set(jnp.sign(x[0]))
            x = x.at[0].set(x[0] - D[n-1]*jnp.sqrt((x*x).sum()))
            # Householder transformation
            Hx = (jnp.eye(dim-n+1) - 2.*jnp.outer(x, x)/(x*x).sum())
            mat = jnp.eye(dim)
            mat = mat.at[n-1:, n-1:].set(Hx)
            H = jnp.dot(H, mat)
        # Fix the last sign such that the determinant is 1
        D = D.at[-1].set((-1)**(1-(dim % 2))*D.prod())
        # Equivalent to np.dot(np.diag(D), H) but faster, apparently
        H = (D*H.T).T
        return H
    return jax.jit(rvs_aux)
```
Is this how such things should be written? Any words of wisdom appreciated.
> Fool that I was, I thought that one could port this into jax by just replacing `np` with `jnp`. This, however seems to be far from true.
Yeah, random numbers and in-place updates are two of the main incompatibilities mentioned in the Sharp Bits docs.
> So, the correct syntax appears to be `D = D.at[n-1].set(jnp.sign(x[0]))`. Is this really so?
Yes, this is the best method to use instead of in-place updates. As for its grossness, I suppose it's a matter of preference. I really like JAX's functional update syntax, but I understand if you disagree.
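To make the semantics concrete: indexed assignment on a JAX array raises a `TypeError`, while `.at[...].set(...)` returns an updated copy and leaves the original untouched, which is why the result has to be assigned back. A minimal sketch:

```python
import jax.numpy as jnp

D0 = jnp.ones((3,))

# In-place assignment is rejected because JAX arrays are immutable
try:
    D0[0] = -1.0
    inplace_ok = True
except TypeError:
    inplace_ok = False

# .at[].set() is out-of-place: it returns a new array, D0 is unchanged
D1 = D0.at[0].set(-1.0)

# Idiomatic form: rebind the same name to the updated copy
D = jnp.ones((3,))
D = D.at[0].set(-1.0)
```

Under `jax.jit`, when the input array is not reused afterwards, the compiler can perform such an update in place, so the copy is mostly a matter of semantics rather than cost.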
Looking at your final function, your approach is close to what I'd call "best practice", but I'd simplify it a bit by wrapping the main function in `jit` with `static_argnames`, and also by making the random key an explicit argument:
```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['dim'])
def rvs(key, dim=3):
    H = jnp.eye(dim)
    D = jnp.ones((dim,))
    for n in range(1, dim):
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, shape=(dim-n+1,))
        D = D.at[n-1].set(jnp.sign(x[0]))
        x = x.at[0].set(x[0] - D[n-1]*jnp.sqrt((x*x).sum()))
        # Householder transformation
        Hx = (jnp.eye(dim-n+1) - 2.*jnp.outer(x, x)/(x*x).sum())
        mat = jnp.eye(dim).at[n-1:, n-1:].set(Hx)
        H = jnp.dot(H, mat)
    # Fix the last sign such that the determinant is 1
    D = D.at[-1].set((-1)**(1-(dim % 2))*D.prod())
    # Equivalent to np.dot(np.diag(D), H) but faster, apparently
    H = (D*H.T).T
    return H

key = jax.random.PRNGKey(42)
rvs(key, 3)
```
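`dim` has to be listed in `static_argnames` because `range(1, dim)` and `jnp.eye(dim)` need a concrete Python integer while the function is being traced. A small sketch with hypothetical function names (`eye_static`, `eye_traced`) showing the difference:

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['n'])
def eye_static(n):
    # n is an ordinary Python int during tracing, so shapes can depend on it
    return jnp.eye(n)

@jax.jit
def eye_traced(n):
    # n is an abstract tracer here, so it cannot be used as a shape
    return jnp.eye(n)

I3 = eye_static(3)  # each distinct n is traced and compiled once, then cached

try:
    eye_traced(3)
    traced_ok = True
except TypeError:
    # JAX's concretization / tracer-conversion errors subclass TypeError
    traced_ok = False
```

Each new value of `dim` causes a fresh trace and compilation, which is usually acceptable since `dim` tends to take only a handful of values.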
Note that I made the random key an explicit argument to the function; you'll find that this is the pattern used for all functionality in `jax.random`. The reason is that JAX is built around pure functions, so there is no global random state, and best practice is to pass keys explicitly (see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers for more discussion of this).
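There is also a concrete pitfall in the closure version from the question. `jax.jit` runs the Python body only while it traces the function, so the `nonlocal key` update happens once, at trace time, and the split of the initial key is baked into the compiled computation; later calls reuse the cached computation and return the same sample. The sketch below (the names `make_sampler`, `sample` and `draw` are hypothetical, and it draws a short vector instead of a matrix) contrasts that with passing the key explicitly:

```python
import jax
import jax.numpy as jnp

def make_sampler():
    key = jax.random.PRNGKey(42)
    def sample():
        nonlocal key
        # Runs only while jit traces; key is rebound to a tracer, not a usable new key
        key, subkey = jax.random.split(key)
        return jax.random.normal(subkey, shape=(2,))
    return jax.jit(sample)

sample = make_sampler()
first, second = sample(), sample()  # identical: the cached computation never changes

@jax.jit
def draw(key):
    # Explicit key: the caller controls the random state
    return jax.random.normal(key, shape=(2,))

key = jax.random.PRNGKey(42)
key, k1 = jax.random.split(key)
key, k2 = jax.random.split(key)
x1, x2 = draw(k1), draw(k2)  # different keys give different samples
```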