Reputation: 113
I have the following code, which uses a simple for loop, and I was wondering if there is a way to vmap it. Here is the original code:
import numpy as np
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal
import jax
data = np.random.rand(192,334)
a = [1,-1.086740193996892,0.649914553946275,-0.124948974636730]
b = [0.054778173164082,0.164334519492245,0.164334519492245,0.054778173164082]
impulse = signal.lfilter(b, a, [1] + [0]*99)
impulse_20 = impulse[:20]
impulse_20 = jnp.asarray(impulse_20)
@jax.jit
def filter_jax(y):
    for ind in range(0, len(y)):
        # jax.ops.index_update has been removed from newer JAX; y.at[...].set(...) is the equivalent.
        y = y.at[:, ind].set(jscp.convolve(impulse_20, y[:, ind])[:-19])
    return y
jnpData = jnp.asarray(data)
%timeit filter_jax(jnpData).block_until_ready()
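For reference, the loop body's `jscp.convolve(impulse_20, y[:, ind])[:-19]` is ordinary causal FIR filtering of one column: the full convolution is `len(taps) - 1` samples longer than the column, and dropping that tail leaves an output aligned with the input. The sketch below checks this against `scipy.signal.lfilter` with denominator `[1.0]`; the 3-tap kernel `taps` and the input `x` are hypothetical toy values, not taken from the original code.

```python
import numpy as np
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal

taps = np.array([0.5, 0.3, 0.2])  # hypothetical 3-tap kernel (stands in for impulse_20)
x = np.arange(6, dtype=np.float32)  # hypothetical input column

# Full convolution has len(x) + len(taps) - 1 samples; dropping the last
# len(taps) - 1 keeps the causal part, aligned with x.
via_convolve = jscp.convolve(jnp.asarray(taps), jnp.asarray(x))[:-(len(taps) - 1)]

# lfilter with denominator [1.0] is plain FIR filtering with the same taps.
via_lfilter = signal.lfilter(taps, [1.0], x)

print(np.allclose(via_convolve, via_lfilter, atol=1e-6))
```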
And here is my attempt at using vmap:
def paraUpdate(y, ind):
    return y.at[:, ind].set(jscp.convolve(impulse_20, y[:, ind])[:-19])
@jax.jit
def filter_jax2(y):
    ranger = range(0, len(y))
    return jax.vmap(paraUpdate, y)(ranger)
But I receive the following error:
TypeError: vmap in_axes must be an int, None, or (nested) container with those types as leaves, but got Traced<ShapedArray(float32[192,334])>with<DynamicJaxprTrace(level=0/1)>.
I'm a little confused, since range yields ints, so I'm not sure what's going on.
Ultimately, I'm trying to optimize this small piece of code to run as fast as possible.
Upvotes: 1
Views: 1156
Reputation: 86320
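jax.vmap can express functionality in which a single operation is independently applied to each slice along an axis of an input. Your function is a bit different: you have a single operation iteratively applied to a single input. The TypeError itself comes from the call signature: the second positional argument of jax.vmap is in_axes, so jax.vmap(paraUpdate, y) passes the array y where an int, None, or a container of those is expected.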
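To see why vmap does not fit even with a valid in_axes, here is a small sketch. `set_column` is a hypothetical stand-in for paraUpdate that just writes 1.0 into one column, and the shapes are toy values. Mapping over the indices while broadcasting `y` gives one separate copy of `y` per index, each with a single column changed, rather than one array with every column updated.

```python
import jax
import jax.numpy as jnp

def set_column(y, ind):
    # Hypothetical stand-in for paraUpdate: overwrite column `ind` of y with 1.0.
    return y.at[:, ind].set(1.0)

y = jnp.zeros((3, 4))
idx = jnp.arange(3)

# Passing an array as the second positional argument (in_axes) is rejected.
try:
    jax.vmap(set_column, y)(idx)
except TypeError as err:
    print(type(err).__name__)

# Valid in_axes: broadcast y (None), map over idx (axis 0).
batched = jax.vmap(set_column, in_axes=(None, 0))(y, idx)
print(batched.shape)  # one (3, 4) copy of y per index -> (3, 3, 4)
```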
Fortunately, JAX provides lax.scan, which can handle this situation. The implementation would look something like this:
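from jax import lax
def paraUpdate(y, ind):
    # scan body: takes (carry, x) and returns (new_carry, per-step output).
    # The carry is the array being filtered; x is the column index.
    return y.at[:, ind].set(jscp.convolve(impulse_20, y[:, ind])[:-19]), ind
@jax.jit
def filter_jax2(y):
    ranger = jnp.arange(len(y))
    # scan returns (final_carry, stacked_outputs); keep the final array.
    return lax.scan(paraUpdate, y, ranger)[0]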
print(np.allclose(filter_jax(jnpData), filter_jax2(jnpData)))
# True
%timeit filter_jax(jnpData).block_until_ready()
# 10 loops, best of 3: 28.6 ms per loop
%timeit filter_jax2(jnpData).block_until_ready()
# 1000 loops, best of 3: 519 µs per loop
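lax.scan threads a carry through a loop: the body takes `(carry, x)` and returns `(new_carry, y)`, and scan returns the final carry together with the stacked `y` values. In paraUpdate above, the carry is the array being filtered and `x` is the column index. The sketch below compares lax.scan with a pure-Python loop written after the semantics given in the lax.scan documentation, on a hypothetical running-sum example (`scan_reference` and `step` are illustrative names).

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def scan_reference(f, init, xs):
    # Pure-Python version of lax.scan's semantics for a single array xs.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def step(total, x):
    # Carry the running total and also emit it as this step's output.
    total = total + x
    return total, total

xs = jnp.arange(1.0, 5.0)  # values 1, 2, 3, 4
init = jnp.zeros(())       # float32 scalar carry, same dtype as xs

final, running = lax.scan(step, init, xs)
ref_final, ref_running = scan_reference(step, init, xs)
# final == 10, running == [1, 3, 6, 10], matching the reference loop
```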
If you change your algorithm so that you're applying the operation to every column in the array rather than the first N columns, it can be expressed with vmap like this:
@jax.jit
def filter_jax3(y):
    f = lambda col: jscp.convolve(impulse_20, col)[:-19]
    # Map f over axis 1 (columns) and stack the results back along axis 1.
    return jax.vmap(f, in_axes=1, out_axes=1)(y)
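If you want to keep the original first-N-columns behavior, note that step `ind` of the loop reads only column `ind`, which no earlier step modifies, so the per-column results are independent. That means the same output can be sketched as a vmap over a slice of the first N columns, written back with `.at[].set()`. The kernel `taps` (a hypothetical 3-tap stand-in for impulse_20) and the names `fir`, `filter_loop`, and `filter_vmap_first_n` are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp

taps = jnp.array([0.5, 0.3, 0.2])  # hypothetical kernel standing in for impulse_20

def fir(col):
    # Causal FIR filter of one column: keep the first len(col) output samples.
    return jscp.convolve(taps, col)[:-(taps.shape[0] - 1)]

def filter_loop(y):
    # Same structure as filter_jax: update the first len(y) columns one by one.
    for ind in range(len(y)):
        y = y.at[:, ind].set(fir(y[:, ind]))
    return y

@jax.jit
def filter_vmap_first_n(y):
    # Step `ind` only reads column `ind`, untouched by earlier steps,
    # so the first n columns can be filtered in a single vmap.
    n = y.shape[0]
    filtered = jax.vmap(fir, in_axes=1, out_axes=1)(y[:, :n])
    return y.at[:, :n].set(filtered)

y = jnp.asarray(np.random.default_rng(0).random((6, 10)), dtype=jnp.float32)
print(np.allclose(filter_loop(y), filter_vmap_first_n(y), atol=1e-6))
```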
Upvotes: 1