Reputation: 151
I am looking to translate a bit of code from the NumPy version listed here to a JAX-compatible version. The NumPy code iteratively calculates the value of a matrix E from the values of other matrices A, B, D, as well as the value of E from the previous iteration, E_jm1.
Both the NumPy and JAX versions work in their listed forms and produce identical results. How can I get the JAX version to work when passing A, B, D as a tuple instead of as a concatenated array? I have a specific use case where a tuple would be more useful.
I found a question asking something similar, but it just confirmed that this should be possible. There are no examples in the documentation or elsewhere that I could find.
Original NumPy version
import numpy as np
import jax
import jax.numpy as jnp
def BAND_J(A, B, D, E_jm1):
    '''
    output: E (N x N)
    input: A (N x N), B (N x N), D (N x N), E_jm1 (N x N)
    E_j = -[B + A E_{j-1}]^-1 D
    '''
    B_inv = np.linalg.inv(B + np.dot(A, E_jm1))
    E = -np.dot(B_inv, D)
    return E
key = jax.random.PRNGKey(0)
N = 2
NJ = 4
# initialize matrices with random values (one subkey each, so A, B, D differ)
kA, kB, kD, kE = jax.random.split(key, 4)
A, B, D = [jax.random.normal(kA, shape=(N, N, NJ)),
           jax.random.normal(kB, shape=(N, N, NJ)),
           jax.random.normal(kD, shape=(N, N, NJ))]
A_np, B_np, D_np = [np.asarray(A), np.asarray(B), np.asarray(D)]
# initialize E_0
E_0 = jax.random.normal(kE, shape=(N, N))
E_np = np.empty((N, N, NJ))
E_np[:, :, 0] = np.asarray(E_0)
# iteratively calculate E_j from A, B, D, and E_{j-1}
for j in range(1, NJ):
    E_jm1 = E_np[:, :, j - 1]
    E_np[:, :, j] = BAND_J(A_np[:, :, j], B_np[:, :, j], D_np[:, :, j], E_jm1)
JAX scan version
def BAND_J(E, ABD):
    '''
    output: E (N x N)
    input: E = E_jm1 (N x N), ABD = (A, B, D), each (N x N)
    '''
    A, B, D = ABD
    B_inv = jnp.linalg.inv(B + jnp.dot(A, E))
    E = -jnp.dot(B_inv, D)
    return E, E  # ("carryover", "accumulated")
abd = jnp.asarray([(A[:,:,j], B[:,:,j], D[:,:,j]) for j in range(NJ)])
# abd = tuple([(A[:,:,j], B[:,:,j], D[:,:,j]) for j in range(NJ)]) # this produces error
# ValueError: too many values to unpack (expected 3)
# skip j = 0 so the scan starts at j = 1, like the NumPy loop
_, E = jax.lax.scan(BAND_J, E_0, abd[1:])
for j in range(1, NJ):
    print(np.isclose(E[j - 1], E_np[:, :, j]))
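For reference, here is a minimal sketch of what scan does with the commented-out tuple, assuming the question's shapes (N = 2, NJ = 4) with placeholder zero inputs; `step` is a hypothetical name. scan flattens xs as a pytree and slices the leading axis of every leaf, so each step receives the whole outer tuple structure rather than one (A, B, D) triple, and the number of steps comes from the leaves' first axis (N), not from NJ.

```python
import jax
import jax.numpy as jnp

N, NJ = 2, 4
# Placeholder inputs with the question's shapes; only the shapes matter here.
A = B = D = jnp.zeros((N, N, NJ))

# A tuple of NJ per-step tuples: a pytree with 3 * NJ leaves, each of shape (N, N).
xs = tuple((A[:, :, j], B[:, :, j], D[:, :, j]) for j in range(NJ))

def step(carry, x):
    # scan slices the leading axis of every leaf, so x keeps the outer
    # structure of xs: NJ entries, each a 3-tuple of shape-(N,) rows.
    assert len(x) == NJ and len(x[0]) == 3
    assert x[0][0].shape == (N,)
    # Unpacking "A, B, D = x" would therefore raise
    # "ValueError: too many values to unpack (expected 3)" when NJ != 3.
    return carry + 1, None

# Count the steps: the scan length is the leaves' leading axis (N), not NJ.
n_steps, _ = jax.lax.scan(step, 0, xs)
print(n_steps)  # 2
```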
Upvotes: 2
Views: 2712
Reputation: 53
In general, scan can iterate over a tuple of arrays, slicing each array along its leading axis at every step:
arr1 = jnp.arange(10)
arr2 = jnp.arange(10)
_, E = jax.lax.scan(fun, E_0, (arr1, arr2))
# fun(carry, x) will receive x = (0, 0), then (1, 1), ..., (9, 9)
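For your code, move the NJ axis to the front of each array (A.T would reverse all three axes, so A.T[j] would be A[:,:,j] transposed) and skip j = 0 as the NumPy loop does:
xs = tuple(jnp.moveaxis(M, -1, 0)[1:] for M in (A, B, D))  # each has shape (NJ-1, N, N)
_, E = jax.lax.scan(BAND_J, E_0, xs)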
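A self-contained sketch of the tuple-of-arrays approach, assuming the question's recurrence and shapes. The step function is renamed `band_j`, A, B, D get separate PRNG subkeys, and the result is checked against the same recurrence run as an eager Python loop; these names and the tolerances are illustrative choices, not part of the original answer.

```python
import numpy as np
import jax
import jax.numpy as jnp

def band_j(E, ABD):
    # With a tuple of arrays as xs, ABD is one (A_j, B_j, D_j) triple per step.
    A, B, D = ABD
    E = -jnp.linalg.inv(B + A @ E) @ D
    return E, E  # (carry, per-step output)

N, NJ = 2, 4
kA, kB, kD, kE = jax.random.split(jax.random.PRNGKey(0), 4)
A = jax.random.normal(kA, (N, N, NJ))
B = jax.random.normal(kB, (N, N, NJ))
D = jax.random.normal(kD, (N, N, NJ))
E_0 = jax.random.normal(kE, (N, N))

# Put the NJ axis first so each array in xs has shape (NJ - 1, N, N) and its
# row j - 1 equals M[:, :, j]; j = 0 is skipped, like the NumPy loop.
# (A.T would reverse all axes and give A[:, :, j].T instead.)
xs = tuple(jnp.moveaxis(M, -1, 0)[1:] for M in (A, B, D))
_, E = jax.lax.scan(band_j, E_0, xs)

# Eager reference: the same recurrence as a plain Python loop
E_prev, ref = E_0, []
for j in range(1, NJ):
    E_prev, _ = band_j(E_prev, (A[:, :, j], B[:, :, j], D[:, :, j]))
    ref.append(E_prev)
ref = jnp.stack(ref)

print(E.shape)  # (3, 2, 2)
print(np.allclose(E, ref, rtol=1e-3, atol=1e-4))
```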
Upvotes: 0
Reputation: 86443
The short answer is "you can't". By design, jax.lax.scan scans over the leading axis of arrays, not over the entries of arbitrary Python collections. So if you want to use scan, you'll have to stack your entries into an array.
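One way to do that stacking without a Python-level list comprehension, as a minimal sketch with the question's shapes (the subkey names are illustrative): jnp.stack puts A, B, D on a new leading axis, and jnp.moveaxis then brings the NJ axis to the front, giving the same (NJ, 3, N, N) array the question builds with jnp.asarray. Slicing it as abd[1:] starts the scan at j = 1 like the NumPy loop.

```python
import jax
import jax.numpy as jnp

N, NJ = 2, 4
kA, kB, kD = jax.random.split(jax.random.PRNGKey(0), 3)
A, B, D = (jax.random.normal(k, (N, N, NJ)) for k in (kA, kB, kD))

# (3, N, N, NJ) -> (NJ, 3, N, N): abd[j] stacks A[:, :, j], B[:, :, j], D[:, :, j]
abd = jnp.moveaxis(jnp.stack([A, B, D]), -1, 0)

# A scan step receiving abd[j] (shape (3, N, N)) can unpack it along its leading axis
A_j, B_j, D_j = abd[1]
print(abd.shape)                         # (4, 3, 2, 2)
print(jnp.array_equal(A_j, A[:, :, 1]))  # True
```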
That said, since your tuple only has three elements, a good alternative would be to skip the scan and simply JIT-compile the for loop approach. JAX tracing will effectively unroll the loop and optimize the flattened sequence of operations. While this can lead to long compile times for large loops, your application only has 3 iterations, so it shouldn't be problematic.
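A minimal sketch of that suggestion, assuming the question's tuple-of-triples layout and its recurrence E_j = -[B + A E_{j-1}]^-1 D; `band_loop` is a hypothetical name. Because the tuple's length is part of its pytree structure, it is static under jax.jit, so the Python for loop is unrolled at trace time.

```python
import jax
import jax.numpy as jnp

@jax.jit
def band_loop(E_0, abd):
    # abd: tuple of per-step (A_j, B_j, D_j) triples. Its length is static,
    # so tracing unrolls this Python loop into a flat sequence of operations.
    E, out = E_0, []
    for A_j, B_j, D_j in abd:
        E = -jnp.linalg.inv(B_j + A_j @ E) @ D_j
        out.append(E)
    return jnp.stack(out)

N, NJ = 2, 4
keys = jax.random.split(jax.random.PRNGKey(0), 4)
A, B, D = (jax.random.normal(k, (N, N, NJ)) for k in keys[:3])
E_0 = jax.random.normal(keys[3], (N, N))

# The tuple form from the question, starting at j = 1 like the NumPy loop
abd = tuple((A[:, :, j], B[:, :, j], D[:, :, j]) for j in range(1, NJ))
E = band_loop(E_0, abd)
print(E.shape)  # (3, 2, 2)
```

Since the loop is unrolled, the compiled program grows with the number of steps, and passing a tuple of a different length triggers a fresh trace and compilation.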
Upvotes: 2