DavidJ

Irregular/Inhomogeneous Arrays with JAX

What is the recommended approach to implementing array behaviour/methods in JAX on irregular (inhomogeneous) data that possesses some inherent dimensionality?

Two principal options come to mind:

  1. make homogeneous and use a mask
  2. flatten and implement custom methods (i.e. broadcasting and reduction)
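As a toy illustration of the two layouts, a minimal sketch using a small hypothetical ragged input (`ragged`, not taken from the question) stores the same data either as a padded array plus a Boolean mask (option 1) or as a flat array plus the index of the set each element belongs to (option 2):

```python
import numpy as np

# Hypothetical ragged input: three sets with 4, 2 and 1 elements
ragged = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0]]
lengths = np.array([len(r) for r in ragged])

# Option 1: pad every set to the longest one and keep a mask of valid entries
width = lengths.max()
mask = np.arange(width)[None, :] < lengths[:, None]
padded = np.zeros((len(ragged), width))
padded[mask] = np.concatenate(ragged)  # fills valid slots in row-major order

# Option 2: keep only the valid entries plus the set index of each one
flat = np.concatenate(ragged)
segment_ids = np.repeat(np.arange(len(ragged)), lengths)

print(padded.size, flat.size)  # 12 7
```

The padded layout pays for `len(ragged) * max(lengths)` slots, while the flat layout pays only for the valid entries plus one integer index per entry.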

Option 1 is clearly favourable, as it requires less implementation overhead (and consequently less validation/testing). The concern is memory complexity: padding every set to the same length stores (and computes on) many unused entries. In situations where memory is paramount (e.g. to avoid having to distribute the array), is there a better alternative to option 2 that can still exploit the highly optimised array methods?
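For option 2, JAX already ships some optimised building blocks for the flattened layout. A minimal sketch, assuming the flat values and their segment ids are known in advance (the data below is hypothetical), uses `jax.ops.segment_sum` for the reduction and plain integer indexing (a gather) to broadcast each per-segment result back to its elements:

```python
import jax
import jax.numpy as jnp

# Hypothetical flattened data: 7 values spread over 3 segments
flat = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
segment_ids = jnp.array([0, 0, 0, 0, 1, 1, 2])
num_segments = 3  # a Python int, so the output shape is static under jit


@jax.jit
def normalise_per_segment(values):
    # Reduction: one sum per segment, shape (num_segments,)
    totals = jax.ops.segment_sum(values, segment_ids, num_segments=num_segments)
    # Broadcast back: look up each element's segment total, shape (7,)
    return values / totals[segment_ids]


normalised = normalise_per_segment(flat)
```

Passing `num_segments` explicitly matters under `jax.jit`, since otherwise the number of segments would have to be inferred from the data.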

EDIT: The following is a concrete example that exhibits this sparsity.

import jax as jx
import jax.numpy as jnp
jx.config.update("jax_enable_x64", True)


# Problem specific variables (static)
n_vars = 3 # Number of variable sets
n_smps = 10 # Maximum number of set elements
p_smps = 0.2 # Representation of problem sparsity


# Each set contains a differing number of elements (binomial random for example)
n_lvls = jx.random.bernoulli(
    jx.random.PRNGKey(0),
    p_smps,
    (n_vars, n_smps)
).sum(axis=1, dtype='i4')


# Derived quantities depend on constant coefficients (uniform random for example)
a_vars = jx.random.uniform(jx.random.PRNGKey(1), (n_vars, ), dtype='f8')

b_vars = jx.random.uniform(jx.random.PRNGKey(2), (n_vars, ), dtype='f8')
b_vars = 10.0*b_vars

c_vars = jx.random.uniform(jx.random.PRNGKey(3), (n_vars, ), dtype='f8')
c_vars = 2.0*c_vars

The problem is intrinsically represented by a 7-element state (the sum of n_lvls for these seeds). What follows is one implementation of option 1:

### Homogeneous with mask ###

# Define the level index array
i_smps = jnp.arange(n_smps, dtype='i4')
mask = n_lvls[:,None]>i_smps[None,:]

# Generate an initial state that respects the unity axiom
x_vars = 1.0/(1.0+n_lvls[:,None]*i_smps[None,:]).astype('f8')
x_vars = jnp.where(mask, x_vars, 0.0)
x_vars = x_vars/x_vars.sum()


# Generate a coefficient tensor
P_vars = a_vars[:,None]+b_vars[:,None]*i_smps[None,:]
P_vars = jnp.where(mask, P_vars, 0.0)


# Determine a scalar moment
scalar_moment = (x_vars*c_vars[:,None]).sum()
# >>> DeviceArray(0.66574861, dtype=float64)

# Determine a transition tensor
trans_tens = (P_vars[:,:,None,None]-P_vars[None,None,:,:])
trans_tens = trans_tens*x_vars[None,None,:,:]*x_vars[:,:,None,None]
trans_tens.sum(axis=(2,3))
# >>>  DeviceArray([[-0.37032842,  0.16153429,  0.22063015,  0.24335933, ...

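Ensuring homogeneity increases the state from 7 to 30 elements (n_vars*n_smps), and the transition tensor from 7*7 = 49 to 30*30 = 900 elements. Furthermore, computing derived quantities involves numerous multiply-by-zero operations.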
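If only the reduced quantity `trans_tens.sum(axis=(2,3))` is needed, the four-index tensor never has to be built. Summing (P_ij - P_kl) * x_kl * x_ij over k and l gives x_ij * (P_ij * sum(x) - sum(P*x)), and since x is normalised (sum(x) = 1) this is x_ij * (P_ij - sum(P*x)), which needs memory proportional to the state size rather than its square. A minimal sketch with hypothetical random inputs, checking the identity against the dense computation (the same one-liner applies unchanged to the flattened arrays in the answer):

```python
import jax
import jax.numpy as jnp

# Hypothetical padded inputs of shape (n_vars, n_smps); zeros mark padding
key_x, key_p = jax.random.split(jax.random.PRNGKey(0))
mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]], dtype=bool)
x = jnp.where(mask, jax.random.uniform(key_x, mask.shape), 0.0)
x = x / x.sum()
P = jnp.where(mask, jax.random.uniform(key_p, mask.shape), 0.0)

# Dense version: materialises an (n_vars, n_smps, n_vars, n_smps) tensor
dense = ((P[:, :, None, None] - P[None, None, :, :])
         * x[None, None, :, :] * x[:, :, None, None]).sum(axis=(2, 3))

# Contracted version: sum_kl (P_ij - P_kl) x_kl x_ij
#                   = x_ij * (P_ij * sum(x) - sum(P * x))
reduced = x * (P * x.sum() - (P * x).sum())
```

Padded entries stay zero in both versions because x is zero there, so the contraction is compatible with the masked layout as well.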
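In the masked code, the zero fill is what keeps padded entries out of the sums. As an alternative, `jnp.sum` accepts a `where=` argument that applies the mask at reduction time, so padded entries may hold arbitrary values. A minimal sketch with hypothetical inputs (the padding deliberately holds 99.0 to show it is ignored); note that this does not shrink the padded arrays, it only removes the reliance on zero-filling:

```python
import jax.numpy as jnp

# Hypothetical padded layout: 3 sets padded to width 4, with lengths 4, 2, 1
n_lvls = jnp.array([4, 2, 1])
mask = n_lvls[:, None] > jnp.arange(4)[None, :]

# Valid entries hold 0.1; padded entries hold junk (99.0) instead of zeros
x_pad = jnp.where(mask, 0.1, 99.0)
c = jnp.array([1.0, 2.0, 3.0])

# Masked scalar moment: only valid entries contribute
moment = jnp.sum(x_pad * c[:, None], where=mask)
# Per-set totals with the same mask, shape (3,)
per_set = jnp.sum(x_pad, axis=1, where=mask)
```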

Answers (1)

DavidJ

This is one approach that flattens the data with Boolean indexing, so that only the valid elements are stored.

# Required for static Boolean indexing
import numpy as np
n_lvls = np.asarray(n_lvls, dtype='i4')

### Flatten with Boolean Indexing ###

i_grid = np.arange(n_lvls.size, dtype='i4')
j_grid = np.arange(n_lvls.max(), dtype='i4')

# Determine the boolean mask
mask = n_lvls[:,None]>j_grid[None,:]
bc_size = (i_grid.size, j_grid.size)

i = np.broadcast_to(i_grid[:,None], bc_size)[mask]
# >>> array([0, 0, 0, 0, 1, 1, 2], dtype=int32)

j = np.broadcast_to(j_grid[None,:], bc_size)[mask]
# >>> array([0, 1, 2, 3, 0, 1, 0], dtype=int32)

# Generate an initial state that respects the unity axiom
x = 1.0/(1.0+n_lvls[i]*j)
x = x/x.sum()


# Generate a coefficient tensor
P_vars = a_vars[i]+b_vars[i]*j


# Determine scalar moment
scalar_moment = (x*c_vars[i]).sum()
# >>> DeviceArray(0.66574861, dtype=float64)


# Determine transition tensor
trans_tens = (P_vars[:,None]-P_vars[None,:])
trans_tens = trans_tens*x[None,:]*x[:,None]
trans_tens.sum(axis=1)
# >>> DeviceArray([-0.37032842,  0.16153429,  0.22063015,  0.24335933, ...
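The conversion to NumPy at the top of this code is needed because Boolean-mask indexing returns an array whose shape depends on the mask's values, which `jax.jit` cannot trace. If the number of valid elements is known ahead of time, `jnp.nonzero` with a static `size=` argument gives the same `i` and `j` inside a jitted function. A minimal sketch, where `flat_indices`, `n_max` and `n_total` are hypothetical names and the lengths `[4, 2, 1]` are chosen to match the output shown above:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n_max", "n_total"))
def flat_indices(n_lvls, n_max, n_total):
    # Same mask as the Boolean-indexing approach, shape (len(n_lvls), n_max)
    mask = n_lvls[:, None] > jnp.arange(n_max)[None, :]
    # nonzero with a static size returns fixed-shape (row, column) index arrays
    i, j = jnp.nonzero(mask, size=n_total)
    return i, j


i, j = flat_indices(jnp.array([4, 2, 1]), n_max=4, n_total=7)
```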

There may be a more memory-efficient way to determine i and j than Boolean indexing over a dense mask.
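One option is to build `i` and `j` directly from the set lengths with `np.repeat` and a cumulative sum, which avoids the dense `(n_vars, max(n_lvls))` mask and uses memory proportional to the number of valid elements. A minimal sketch, using hypothetical lengths `[4, 2, 1]` that reproduce the `i` and `j` output shown above:

```python
import numpy as np

# Hypothetical set sizes (three sets with 4, 2 and 1 elements)
n_lvls = np.array([4, 2, 1], dtype='i4')

# Set index of every flattened element: 0 four times, 1 twice, 2 once
i = np.repeat(np.arange(n_lvls.size, dtype='i4'), n_lvls)

# Start offset of each set within the flat array, here [0, 4, 6]
offsets = np.cumsum(n_lvls) - n_lvls
# Position of every element within its own set
j = np.arange(n_lvls.sum(), dtype='i4') - np.repeat(offsets, n_lvls)
```

Under `jax.jit`, `jnp.repeat` accepts a `total_repeat_length` argument, which fixes the output shape so the same construction can be traced.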
