DumbCoder21

vmap ops.index_update in Jax

I have the code below, which uses a simple for loop, and I was wondering whether there is a way to vmap it. Here is the original code:

import numpy as np 
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal
import jax

data = np.random.rand(192,334)

a = [1,-1.086740193996892,0.649914553946275,-0.124948974636730]
b = [0.054778173164082,0.164334519492245,0.164334519492245,0.054778173164082]
impulse = signal.lfilter(b, a, [1] + [0]*99) 
impulse_20 = impulse[:20]
impulse_20 = jnp.asarray(impulse_20)

@jax.jit
def filter_jax(y):
    for ind in range(0, len(y)):
      y = jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])
    return y

jnpData = jnp.asarray(data)

%timeit filter_jax(jnpData).block_until_ready()
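jax.ops.index_update has since been removed from JAX in favour of the indexed-update syntax y.at[idx].set(value). Below is a minimal sketch of the same loop in that syntax, assuming a recent JAX release. The name filter_loop_at and the reduced (8, 12) shape are illustrative choices, not part of the original code. It also checks the result against NumPy, because convolve(...)[:-19] keeps the first len(col) samples of the full convolution, which makes it a causal FIR filter with taps impulse_20:

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal

# Same IIR coefficients as in the question; the first 20 samples of its
# impulse response are used as FIR taps.
a = [1, -1.086740193996892, 0.649914553946275, -0.124948974636730]
b = [0.054778173164082, 0.164334519492245, 0.164334519492245, 0.054778173164082]
impulse_20 = jnp.asarray(signal.lfilter(b, a, [1] + [0] * 99)[:20])

@jax.jit
def filter_loop_at(y):
    # Hypothetical name: same loop as filter_jax, but using y.at[...].set(...)
    # instead of the removed jax.ops.index_update.
    for ind in range(y.shape[0]):  # len(y) == number of rows, as in the question
        filtered = jscp.convolve(impulse_20, y[:, ind])[:-19]
        y = y.at[:, ind].set(filtered)
    return y

data = np.random.rand(8, 12)  # smaller than (192, 334) to keep tracing fast
out = filter_loop_at(jnp.asarray(data))

# NumPy reference: keep the first len(col) samples of the full convolution.
ref = data.copy()
for ind in range(data.shape[0]):
    ref[:, ind] = np.convolve(np.asarray(impulse_20), data[:, ind])[:data.shape[0]]
print(np.allclose(out, ref, atol=1e-5))
```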

And here is my attempt at using vmap:

def paraUpdate(y, ind):
    return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])

@jax.jit
def filter_jax2(y):
  ranger = range(0, len(y))
  return jax.vmap(paraUpdate, y)(ranger)

But I receive the following error:

TypeError: vmap in_axes must be an int, None, or (nested) container with those types as leaves, but got Traced<ShapedArray(float32[192,334])>with<DynamicJaxprTrace(level=0/1)>.

I'm a little confused, since the values produced by range are ints, so I'm not sure what's going on.

Ultimately, I'm trying to optimize this small piece of code so that it runs in as little time as possible.

Answers (1)

jakevdp

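jax.vmap expresses computations in which a single operation is applied independently to each slice of an input along some axis. Your function is a bit different: you have a single operation applied iteratively to a single input, with each step updating the array produced by the previous one. (The TypeError itself comes from the call: the second positional argument of jax.vmap is in_axes, so jax.vmap(paraUpdate, y) passes the array y where an axis specification is expected.)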
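To see what vmap does with an update function like paraUpdate once in_axes is given explicitly, here is a minimal sketch. update_column is a hypothetical stand-in that doubles a column instead of convolving it, written with the y.at[...].set(...) syntax. With in_axes=(None, 0), y is shared by every call and only the index is mapped, so each mapped call returns its own full copy of y with one column replaced. The copies are stacked, not combined into one updated array:

```python
import jax
import jax.numpy as jnp

def update_column(y, ind):
    # Hypothetical stand-in for paraUpdate: returns a copy of y with column
    # `ind` doubled (a placeholder for the convolution step).
    return y.at[:, ind].set(2 * y[:, ind])

y = jnp.arange(12.0).reshape(3, 4)
inds = jnp.arange(3)

# in_axes=(None, 0): y is broadcast to every call, ind is mapped over.
batched = jax.vmap(update_column, in_axes=(None, 0))(y, inds)
print(batched.shape)  # (3, 3, 4): one full array per index
```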

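Fortunately, JAX provides lax.scan, which handles this situation: it threads a carry value (here, the array being updated) through a sequence of steps. The scanned function must return a (carry, output) pair, which is why paraUpdate below also returns ind. The implementation would look something like this: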

from jax import lax

def paraUpdate(y, ind):
    return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19]), ind

@jax.jit
def filter_jax2(y):
  ranger = jnp.arange(len(y))
  return lax.scan(paraUpdate, y, ranger)[0]

print(np.allclose(filter_jax(jnpData), filter_jax2(jnpData)))
# True

%timeit filter_jax(jnpData).block_until_ready()
# 10 loops, best of 3: 28.6 ms per loop

%timeit filter_jax2(jnpData).block_until_ready()
# 1000 loops, best of 3: 519 µs per loop
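A minimal sketch of the same lax.scan pattern written with the y.at[...].set(...) syntax, for JAX versions in which jax.ops.index_update is no longer available. The names step and filter_scan, the illustrative 20-tap vector taps (standing in for impulse_20), and the small (8, 12) shape are assumptions:

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp
from jax import lax

# Illustrative 20-tap filter standing in for impulse_20 (assumption).
taps = jnp.linspace(1.0, 0.05, 20)

def step(y, ind):
    # `ind` is a traced integer, so y[:, ind] and y.at[:, ind] use dynamic
    # indexing; [:-19] keeps the column length equal to the number of rows.
    col = jscp.convolve(taps, y[:, ind])[:-19]
    return y.at[:, ind].set(col), None  # (new carry, no per-step output)

@jax.jit
def filter_scan(y):
    # Scan over the first len(y) column indices, as in the question's loop.
    y_final, _ = lax.scan(step, y, jnp.arange(y.shape[0]))
    return y_final

data = np.random.rand(8, 12).astype(np.float32)
out = filter_scan(jnp.asarray(data))
```

Returning None as the per-step output means scan stacks nothing and only the final carry is kept.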

If you change your algorithm so that it applies the operation to every column of the array, rather than only the first len(y) columns (len(y) is the number of rows, so 192 of the 334 columns here), it can be expressed with vmap like this:

@jax.jit
def filter_jax3(y):
  f = lambda col: jscp.convolve(impulse_20, col)[:-19]
  return jax.vmap(f, in_axes=1, out_axes=1)(y)
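Each iteration of the original loop reads and writes only column ind, so no iteration sees another's result. That means the loop's exact behaviour (filtering only the first len(y) columns) can also be written with vmap by filtering that slice and writing it back. Below is a minimal sketch; filter_col, filter_first_n and filter_loop are hypothetical names, and taps is an illustrative stand-in for impulse_20:

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp

taps = jnp.linspace(1.0, 0.05, 20)  # hypothetical stand-in for impulse_20

def filter_col(col):
    # Causal FIR filter: keep the first len(col) samples of the full convolution.
    return jscp.convolve(taps, col)[:-19]

@jax.jit
def filter_first_n(y):
    n = y.shape[0]  # the question's loop runs over range(len(y)), i.e. n = rows
    filtered = jax.vmap(filter_col, in_axes=1, out_axes=1)(y[:, :n])
    return y.at[:, :n].set(filtered)

@jax.jit
def filter_loop(y):
    # Reference: the question's loop, written with .at[...].set(...).
    for ind in range(y.shape[0]):
        y = y.at[:, ind].set(filter_col(y[:, ind]))
    return y

data = jnp.asarray(np.random.rand(8, 12).astype(np.float32))
print(np.allclose(filter_first_n(data), filter_loop(data), atol=1e-5))
```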
