Igor Rivin

Reputation: 4864

jax confusion - what are the best practices?

Consider the following function (stolen from scipy), which produces a random orthogonal matrix of dimension dim:

import numpy as np

def nrvs(dim=3):
    random_state = np.random
    H = np.eye(dim)
    D = np.ones((dim,))
    for n in range(1, dim):
        x = random_state.normal(size=(dim-n+1,))
        D[n-1] = np.sign(x[0])
        x[0] -= D[n-1]*np.sqrt((x*x).sum())
        # Householder transformation
        Hx = (np.eye(dim-n+1) - 2.*np.outer(x, x)/(x*x).sum())
        mat = np.eye(dim)
        mat[n-1:, n-1:] = Hx
        H = np.dot(H, mat)
    # Fix the last sign such that the determinant is 1
    D[-1] = (-1)**(1-(dim % 2))*D.prod()
    # Equivalent to np.dot(np.diag(D), H) but faster, apparently
    H = (D*H.T).T
    return H

Fool that I was, I thought that one could port this into jax by just replacing np with jnp. This, however, seems to be far from true. The first problem is that jax has a completely different view of random number generation: you pass a key in, and then you "split" that key. Since this function would be used to produce a random SEQUENCE of matrices, it seems that the only reasonable approach is to make this thing into a closure (as below).

A different and even more annoying problem is that something like D[n-1] = np.sign(x[0]) does NOT work in jax because jax arrays are immutable, so the documentation suggests D.at[n-1].set(np.sign(x[0])). However, that does not work either on its own, because the result is a new array which equals the desired modified D, while D itself is left unchanged. So, the correct syntax appears to be: D = D.at[n-1].set(np.sign(x[0])). Is this really so? It looks really gross. The final result was this:

import jax
import jax.numpy as jnp

def rvs(dim=3):
    key = jax.random.PRNGKey(42)
    def rvs_aux():
        nonlocal key
        H = jnp.eye(dim)
        D = jnp.ones((dim,))
        for n in range(1, dim):
            key, subkey = jax.random.split(key)
            x = jax.random.normal(subkey, shape=(dim-n+1,))
            D = (D.at[n-1].set(jnp.sign(x[0])))
            x = (x.at[0].set(x[0] - D[n-1]*jnp.sqrt((x*x).sum())))
            # Householder transformation
            Hx = (jnp.eye(dim-n+1) - 2.*jnp.outer(x, x)/(x*x).sum())
            mat = jnp.eye(dim)
            mat = (mat.at[n-1:, n-1:].set(Hx))
            H = jnp.dot(H, mat)
        # Fix the last sign such that the determinant is 1
        D = (D.at[-1].set((-1)**(1-(dim % 2))*D.prod()))
        # Equivalent to np.dot(np.diag(D), H) but faster, apparently
        H = (D*H.T).T
        return H
    return jax.jit(rvs_aux)

Is this how such things should be written? Any words of wisdom appreciated.

Upvotes: 2

Views: 1842

Answers (1)

jakevdp

Reputation: 86443

"Fool that I was, I thought that one could port this into jax by just replacing np with jnp. This, however, seems to be far from true."

Yeah, random numbers and in-place updates are two of the main incompatibilities mentioned in the Sharp Bits docs.

"So, the correct syntax appears to be: D = D.at[n-1].set(np.sign(x[0])). Is this really so?"

Yes, this is the best method to use instead of in-place updates. As for its grossness, I suppose it's a matter of preference. I really like JAX's functional update syntax, but I understand if you disagree.
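To make the point about rebinding concrete, here is a minimal sketch of the difference between calling .at[...].set(...) and assigning its result back to the name. The variable names D_old and D_new are only for illustration and do not appear in the question.

```python
import jax.numpy as jnp

D_old = jnp.ones((3,))

# Direct item assignment is rejected because JAX arrays are immutable.
try:
    D_old[0] = -1.0
except TypeError as err:
    print("in-place assignment failed:", type(err).__name__)

# .at[...].set(...) leaves D_old alone and returns an updated copy.
D_new = D_old.at[0].set(-1.0)
print(D_old)  # still all ones
print(D_new)  # first entry replaced by -1

# Rebinding the name is what gives the effect of an in-place update.
D = jnp.ones((3,))
D = D.at[0].set(-1.0)
```

As far as I know, when such an update appears inside a jit-compiled function and the old array is not reused, the compiler is generally able to perform it in place, so the functional style need not cost an extra copy.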

Looking at your final function, your approach is close to what I'd call "best practice", but I'd simplify it a bit by wrapping the main function in jit with static_argnames, and also by making the random key an explicit argument:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['dim'])
def rvs(key, dim=3):
    H = jnp.eye(dim)
    D = jnp.ones((dim,))
    for n in range(1, dim):
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, shape=(dim-n+1,))
        D = D.at[n-1].set(jnp.sign(x[0]))
        x = x.at[0].set(x[0] - D[n-1]*jnp.sqrt((x*x).sum()))
        # Householder transformation
        Hx = (jnp.eye(dim-n+1) - 2.*jnp.outer(x, x)/(x*x).sum())
        mat = jnp.eye(dim).at[n-1:, n-1:].set(Hx)
        H = jnp.dot(H, mat)
    # Fix the last sign such that the determinant is 1
    D = D.at[-1].set((-1)**(1-(dim % 2))*D.prod())
    # Equivalent to np.dot(np.diag(D), H) but faster, apparently
    H = (D*H.T).T
    return H

key = jax.random.PRNGKey(42)
rvs(key, 3)

Note that I made the random key an explicit argument to the function; you'll find that this is the pattern used for all functionality in jax.random. The reason is that JAX follows a functional programming model, so there is no global random state, and best practice is to pass keys explicitly (see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers for more discussion of this).
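Since the question asks for a SEQUENCE of random matrices, here is one possible sketch of how that could look with explicit keys: split a single key into several and call the function once per key. The helper name rvs_sequence and its count parameter are hypothetical, made up for this illustration, and rvs is repeated only so that the snippet runs on its own.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['dim'])
def rvs(key, dim=3):
    H = jnp.eye(dim)
    D = jnp.ones((dim,))
    for n in range(1, dim):
        key, subkey = jax.random.split(key)
        x = jax.random.normal(subkey, shape=(dim-n+1,))
        D = D.at[n-1].set(jnp.sign(x[0]))
        x = x.at[0].set(x[0] - D[n-1]*jnp.sqrt((x*x).sum()))
        # Householder transformation
        Hx = (jnp.eye(dim-n+1) - 2.*jnp.outer(x, x)/(x*x).sum())
        mat = jnp.eye(dim).at[n-1:, n-1:].set(Hx)
        H = jnp.dot(H, mat)
    # Fix the last sign such that the determinant is 1
    D = D.at[-1].set((-1)**(1-(dim % 2))*D.prod())
    H = (D*H.T).T
    return H

def rvs_sequence(key, count, dim=3):
    # Hypothetical helper: one independent subkey per matrix.
    keys = jax.random.split(key, count)
    return [rvs(k, dim=dim) for k in keys]

mats = rvs_sequence(jax.random.PRNGKey(42), count=4, dim=3)
for M in mats:
    # Each matrix should be orthogonal up to float32 rounding.
    print(jnp.allclose(M @ M.T, jnp.eye(3), atol=1e-4))
```

Because each call receives its own key, the matrices differ from one another, and rerunning with the same starting key reproduces the same sequence. No closure or nonlocal state is needed.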

Upvotes: 4
