Nick Brady

Reputation: 151

How to iterate over tuples using jax.lax.scan

I am looking to translate some NumPy code (listed below) to a JAX-compatible version. The NumPy code iteratively calculates a matrix E from the matrices A, B, and D, as well as from the value of E at the previous iteration, E_jm1.
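Written out, the recurrence computed by BAND_J is the following, where A_j, B_j, D_j denote the j-th slices along the last axis (A[:,:,j] in the code) and E_0 is given:

```latex
\mathbf{E}_j = -\left[\mathbf{B}_j + \mathbf{A}_j \mathbf{E}_{j-1}\right]^{-1} \mathbf{D}_j,
\qquad j = 1, \dots, N_J - 1
```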

Both the NumPy and JAX versions work in their listed forms and produce identical results. How can I get the JAX version to work when passing A, B, and D as a tuple instead of as a single stacked array? I have a specific use case where a tuple would be more useful.

I found a question asking something similar, but it just confirmed that this should be possible. There are no examples in the documentation or elsewhere that I could find.

Original NumPy version

import numpy as np
import jax
import jax.numpy as jnp

def BAND_J(A, B, D, E_jm1):
    '''
        output: E(N x N)
        input: A(N x N), B(N x N), D(N x N), E_jm1(N x N)
        𝐄ⱼ = -[𝐁 + 𝐀𝐄ⱼ₋₁]⁻¹ 𝐃
    '''

    B_inv = np.linalg.inv(B + np.dot( A, E_jm1 ))

    E  = -np.dot(B_inv, D)

    return E

key = jax.random.PRNGKey(0)

N  = 2
NJ = 4

# initialize matrices with random values
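# initialize matrices with random values (a separate key per array;
# reusing one key would make A, B, and D identical)
kA, kB, kD, kE = jax.random.split(key, 4)
A, B, D = [ jax.random.normal(kA, shape=(N,N,NJ)),
            jax.random.normal(kB, shape=(N,N,NJ)),
            jax.random.normal(kD, shape=(N,N,NJ)) ]
A_np, B_np, D_np = [np.asarray(A), np.asarray(B), np.asarray(D)]

# initialize E_0
E_0 = jax.random.normal(kE, shape=(N,N))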

E_np        = np.empty((N,N,NJ))
E_np[:,:,0] = np.asarray(E_0)

# iteratively calculate E from A, B, D, and 𝐄ⱼ₋₁
for j in range(1,NJ):
    E_jm1       = E_np[:,:,j-1]
    E_np[:,:,j] = BAND_J(A_np[:,:,j], B_np[:,:,j], D_np[:,:,j], E_jm1)

JAX scan version

def BAND_J(E, ABD):
    '''
        output: (E, E), the new carry and the per-step output, each N x N
        input: E = E_jm1 (N x N), ABD = (A, B, D), each N x N
    '''

    A, B, D = ABD

    B_inv = jnp.linalg.inv(B + jnp.dot( A, E ))

    E  = -jnp.dot(B_inv, D)

    return E, E # ("carryover", "accumulated")

# stacked array of shape (NJ-1, 3, N, N); start at j = 1 to match the NumPy loop
abd = jnp.asarray([(A[:,:,j], B[:,:,j], D[:,:,j]) for j in range(1, NJ)])
# abd = tuple([(A[:,:,j], B[:,:,j], D[:,:,j]) for j in range(NJ)]) # this produces error
# ValueError: too many values to unpack (expected 3)

_, E = jax.lax.scan(BAND_J, E_0, abd)

for j in range(1, NJ):
    print(np.isclose(E[j-1], E_np[:,:,j]))
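The error in the commented-out line comes from how scan treats its xs argument: it is flattened as a pytree and every array leaf is sliced along its own leading axis, so the per-step x keeps the outer tuple structure. The sketch below, using all-ones placeholder matrices and a hypothetical helper count_entries, shows the difference between a tuple of per-step triples and a tuple of three stacked arrays.

```python
import jax
import jax.numpy as jnp

N, NJ = 2, 4
mats = jnp.ones((NJ, N, N))  # placeholder data, shape (NJ, N, N)

def count_entries(carry, x):
    # Report how many top-level entries the per-step slice x has
    return carry, jnp.asarray(len(x))

# Tuple of NJ (A_j, B_j, D_j) triples: every (N, N) leaf is sliced along its
# first axis, so scan runs N times and each x is a tuple of NJ triples of rows
per_step = tuple((mats[j], mats[j], mats[j]) for j in range(NJ))
_, bad = jax.lax.scan(count_entries, 0, per_step)

# Tuple of three stacked (NJ, N, N) arrays: scan runs NJ times and each x
# is a triple (A_j, B_j, D_j) of (N, N) matrices
_, good = jax.lax.scan(count_entries, 0, (mats, mats, mats))
```

With NJ = 4, unpacking the first kind of x into A, B, D is what raises "too many values to unpack (expected 3)".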

Upvotes: 2

Views: 2712

Answers (2)

davidec00

Reputation: 53

In general, scan can iterate over a tuple of arrays; at each step the function receives one slice from each array:

arr1 = jnp.arange(10)
arr2 = jnp.arange(10)

def fun(carry, x):
    x1, x2 = x  # x is the tuple (arr1[i], arr2[i])
    return carry, x1 + x2

_, E = jax.lax.scan(fun, 0, (arr1, arr2))
# fun will receive (0, 0), then (1, 1), ..., (9, 9)

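For your code, move the last (iteration) axis of each array to the front. Note that A.T reverses all axes, so each N x N slice would also be transposed; jnp.moveaxis avoids that. Starting at index 1 matches the NumPy loop:

xs = tuple(jnp.moveaxis(M[:, :, 1:], -1, 0) for M in (A, B, D))  # each (NJ-1, N, N)
_, E = jax.lax.scan(BAND_J, E_0, xs)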
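A self-contained check of this approach, sketched under a few assumptions: band_j is a hypothetical stand-in for the question's BAND_J, the random data are placeholders, and B is shifted by 3*I purely so that B + A E stays well conditioned in the demo. It compares the scan over the (A, B, D) tuple with the same recurrence written as a plain Python loop.

```python
import jax
import jax.numpy as jnp

def band_j(E_prev, ABD):
    # ABD is one step's slice of the (A, B, D) tuple: three (N, N) matrices
    A, B, D = ABD
    E = -jnp.dot(jnp.linalg.inv(B + jnp.dot(A, E_prev)), D)
    return E, E  # (carry, per-step output)

N, NJ = 2, 4
kA, kB, kD, kE = jax.random.split(jax.random.PRNGKey(0), 4)
A = jax.random.normal(kA, (N, N, NJ))
# Demo assumption: shift B by 3*I so that B + A E stays well conditioned
B = jax.random.normal(kB, (N, N, NJ)) + 3 * jnp.eye(N)[:, :, None]
D = jax.random.normal(kD, (N, N, NJ))
E_0 = jax.random.normal(kE, (N, N))

# Put the iteration axis first in every array of the tuple (skip j = 0)
xs = tuple(jnp.moveaxis(M[:, :, 1:], -1, 0) for M in (A, B, D))
_, E_scan = jax.lax.scan(band_j, E_0, xs)  # shape (NJ - 1, N, N)

# Reference: the same recurrence as a plain Python loop
E_prev, E_loop = E_0, []
for j in range(1, NJ):
    E_prev, _ = band_j(E_prev, (A[:, :, j], B[:, :, j], D[:, :, j]))
    E_loop.append(E_prev)
E_loop = jnp.stack(E_loop)
```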

Upvotes: 0

jakevdp

Reputation: 86443

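The short answer is "you can't". By design, jax.lax.scan iterates over the leading axis of arrays, not over the entries of an arbitrary Python collection. (A tuple of arrays is accepted as a pytree, but then each array is sliced along its own leading axis; the tuple entries themselves are not the iteration.)

So if you want to use scan, you'll have to stack your entries into an array.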
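If the data already exist as a tuple with one (A_j, B_j, D_j) triple per step, one way to do the stacking is to map jnp.stack over the matching leaves of all the triples with jax.tree_util.tree_map. This is a minimal sketch with placeholder data (mats) and a hypothetical step function that just sums the entries.

```python
import jax
import jax.numpy as jnp

N, NJ = 2, 4
mats = jnp.arange(NJ * N * N, dtype=jnp.float32).reshape(NJ, N, N)

# Per-step layout from the question: one (A_j, B_j, D_j) triple per step
per_step = tuple((mats[j], 2 * mats[j], 3 * mats[j]) for j in range(NJ))

# Stack matching leaves across steps, giving (A_s, B_s, D_s),
# each of shape (NJ, N, N), which scan can iterate over
A_s, B_s, D_s = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *per_step)

def step(total, ABD):
    A, B, D = ABD  # three (N, N) matrices for the current step
    return total + A.sum() + B.sum() + D.sum(), None

total, _ = jax.lax.scan(step, jnp.float32(0.0), (A_s, B_s, D_s))
```

The stacked tuple (A_s, B_s, D_s) gives each step exactly three matrices, which is the shape of input BAND_J expects.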

That said, since your tuple only has three elements, a good alternative would be to skip the scan and simply JIT-compile the for-loop approach. JAX tracing effectively unrolls the loop and optimizes the resulting flat sequence of operations. While this can lead to long compile times for large loops, your loop has only three iterations, so it shouldn't be a problem.
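A minimal sketch of that suggestion, assuming a hypothetical function band_loop that takes E_0 and a plain tuple of (A, B, D) triples. Under jax.jit the Python loop runs once at trace time, so the compiled program is the unrolled sequence of steps. The identity-based inputs are only there to make the result easy to check by hand.

```python
import jax
import jax.numpy as jnp

@jax.jit
def band_loop(E_0, abd):
    # abd is a plain Python tuple of (A, B, D) triples. The loop below is
    # executed during tracing, so it is unrolled into straight-line code.
    E, out = E_0, []
    for A, B, D in abd:
        E = -jnp.linalg.inv(B + A @ E) @ D
        out.append(E)
    return jnp.stack(out)

N = 2
eye = jnp.eye(N)
abd = tuple((eye, 2 * eye, eye) for _ in range(3))
E = band_loop(jnp.zeros((N, N)), abd)  # shape (3, N, N)
```

Here each step gives E_j = -(2I - E_{j-1})^{-1}, so the diagonal runs -1/2, -2/3, -3/4. Because the tuple is part of the pytree structure, passing a tuple of a different length triggers a retrace and recompile.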

Upvotes: 2
