What is the recommended approach to implementing array behaviour/methods on irregular/inhomogeneous data (which nevertheless possesses some inherent dimensionality) within JAX?
Two principal options come to mind:
1. Pad the data out to a homogeneous (rectangular) array and mask the invalid entries, so that the highly optimised jax.numpy array methods apply directly.
2. Keep each variable set in its own appropriately sized array (e.g. a list or pytree of arrays) and implement the required behaviour over that structure.
Clearly option 1 is favourable, as it requires less implementation overhead (and consequently less validation/testing). The concern is memory complexity: in situations where memory is paramount (to avoid having to distribute the array), is there a better alternative to option 2 that can still exploit the highly optimised array methods?
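For reference, a minimal sketch of what option 2 might look like (illustrative only; the data and names here are placeholders, not part of the problem below):
import jax.numpy as jnp
# Hypothetical ragged data: three sets with 4, 2 and 1 elements
sets = [jnp.arange(4.0), jnp.arange(2.0), jnp.arange(1.0)]
# Reductions must be applied leaf-by-leaf rather than as one fused array op
sums = [s.sum() for s in sets]
total = jnp.stack(sums).sum()
The Python-level loop is the drawback: each leaf dispatches its own operation, which is why option 2 cannot exploit the optimised (vectorised) array methods.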
EDIT: The following sets up a concrete example that exhibits sparsity.
import jax as jx
import jax.numpy as jnp
jx.config.update("jax_enable_x64", True)
# Problem specific variables (static)
n_vars = 3 # Number of variable sets
n_smps = 10 # Maximum number of set elements
p_smps = 0.2 # Representation of problem sparsity
# Each set contains a differing number of elements (binomial random for example)
n_lvls = jx.random.bernoulli(
jx.random.PRNGKey(0),
p_smps,
(n_vars, n_smps)
).sum(axis=1, dtype='i4')
# Derived quantities depend on constant coefficients (uniform random for example)
a_vars = jx.random.uniform(jx.random.PRNGKey(1), (n_vars, ), dtype='f8')
b_vars = jx.random.uniform(jx.random.PRNGKey(2), (n_vars, ), dtype='f8')
b_vars = 10.0*b_vars
c_vars = jx.random.uniform(jx.random.PRNGKey(3), (n_vars, ), dtype='f8')
c_vars = 2.0*c_vars
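The per-set element counts can be inspected directly; for this seed they sum to 7, the intrinsic state size referred to below:
# n_lvls holds the number of elements in each set; its sum is the
# number of meaningful entries in the state
print(n_lvls, int(n_lvls.sum()))  # total is 7 for PRNGKey(0)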
The problem is intrinsically represented by a 7-element state. What follows is one implementation of option 1:
### Homogeneous with mask ###
# Define the level index array
i_smps = jnp.arange(n_smps, dtype='i4')
mask = n_lvls[:,None]>i_smps[None,:]
# Generate an initial state that respects the unity axiom
x_vars = 1.0/(1.0+n_lvls[:,None]*i_smps[None,:]).astype('f8')
x_vars = jnp.where(mask, x_vars, 0.0)
x_vars = x_vars/x_vars.sum()
# Generate a coefficient tensor
P_vars = a_vars[:,None]+b_vars[:,None]*i_smps[None,:]
P_vars = jnp.where(mask, P_vars, 0.0)
# Determine a scalar moment
scalar_moment = (x_vars*c_vars[:,None]).sum()
# >>> DeviceArray(0.66574861, dtype=float64)
# Determine a transition tensor
trans_tens = (P_vars[:,:,None,None]-P_vars[None,None,:,:])
trans_tens = trans_tens*x_vars[None,None,:,:]*x_vars[:,:,None,None]
trans_tens.sum(axis=(2,3))
# >>> DeviceArray([[-0.37032842, 0.16153429, 0.22063015, 0.24335933, ...
Ensuring homogeneity increases the state from 7 elements to 30 (the full 3 x 10 array). Furthermore, computing derived quantities involves numerous multiply-by-zero operations.
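This overhead is easy to quantify from the arrays above (a quick check, nothing new):
# Stored versus meaningful entries in the padded representation
print(x_vars.size, int(mask.sum()))         # 30 stored, 7 meaningful
# The padding compounds quadratically in the transition tensor
print(trans_tens.size, int(mask.sum())**2)  # 900 stored, 49 meaningful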
Answer:
This is one approach: flatten the arrays and remove the sparsity using (static) Boolean indexing.
# Required for static Boolean indexing
import numpy as np
n_lvls = np.asarray(n_lvls, dtype='i4')
### Flatten with Boolean Indexing ###
i_grid = np.arange(n_lvls.size, dtype='i4')
j_grid = np.arange(n_lvls.max(), dtype='i4')
# Determine the boolean mask
mask = n_lvls[:,None]>j_grid[None,:]
bc_size = (i_grid.size, j_grid.size)
i = np.broadcast_to(i_grid[:,None], bc_size)[mask]
# >>> array([0, 0, 0, 0, 1, 1, 2], dtype=int32)
j = np.broadcast_to(j_grid[None,:], bc_size)[mask]
# >>> array([0, 1, 2, 3, 0, 1, 0], dtype=int32)
# Generate an initial state that respects the unity axiom
x = 1.0/(1.0+n_lvls[i]*j)
x = x/x.sum()
# Generate a coefficient tensor
P_vars = a_vars[i]+b_vars[i]*j
# Determine scalar moment
scalar_moment = (x*c_vars[i]).sum()
# >>> DeviceArray(0.66574861, dtype=float64)
# Determine transition tensor
trans_tens = (P_vars[:,None]-P_vars[None,:])
trans_tens = trans_tens*x[None,:]*x[:,None]
trans_tens.sum(axis=1)
# >>> DeviceArray([-0.37032842, 0.16153429, 0.22063015, 0.24335933, ...
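As an aside (not part of the original code): if per-variable rather than global reductions are needed in this flattened layout, a segment reduction over the index array i is one option, e.g. with jax.ops.segment_sum:
# Per-variable moments: sum the flattened products within each segment of i
moments = jx.ops.segment_sum(x*c_vars[i], i, num_segments=n_vars)
# moments.sum() recovers the global scalar_moment computed above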
There may be a more memory-efficient way of applying Boolean indexing to determine i and j.
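One candidate is np.nonzero, which returns the same index pairs directly from the mask without materialising the two broadcast grids (note it returns int64 rather than int32):
# Row/column indices of the True mask entries, in the same C order as above
i_alt, j_alt = np.nonzero(mask)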