I'm trying to create a JAX function that selects a set of values from a 2D array and adds them together in a vectorized manner.
More specifically, given an (R x C) array data and a (1 x C) array of ints indxs_to_sum, I can compute data[indxs_to_sum, np.arange(C)].sum(), which adds exactly one element from each column of data.
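As a small illustration of the single-array case, here is a toy (2 x 3) array (the values are made up for the example). The pairs (indxs_to_sum[c], c) pick one element per column; jnp.take_along_axis is shown as an alternative way to express the same selection.

```python
import jax.numpy as jnp

# Toy (R x C) = (2 x 3) array and one row index per column.
data = jnp.array([[1.0, 2.0, 3.0],
                  [10.0, 20.0, 30.0]])
indxs_to_sum = jnp.array([1, 0, 1])
cols = jnp.arange(data.shape[1])

# Index pairs (indxs_to_sum[c], c) select one element from each column.
picked = data[indxs_to_sum, cols]   # values 10, 2, 30
total = picked.sum()                # 42.0

# Same selection with take_along_axis: indices need shape (1, C) along axis 0.
picked_alt = jnp.take_along_axis(data, indxs_to_sum[None, :], axis=0)[0]
```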
The problem comes when I want to do this for several arrays like indxs_to_sum, i.e. I want indxs_to_sum to be a (B x C) array of ints (a batch of index arrays).
My approach (see below) is to vmap the previous function, which compiles into one big Gather XLA operation. This, however, doesn't scale as well as expected: there is a bottleneck on the GPU whenever the intermediate array is created before everything is reduced with a sum.
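To make the intermediate explicit, the sketch below compares the vmapped function with the equivalent batched fancy-indexing expression on small stand-in sizes (R, C, B are illustrative, not the real 400 x 10000 problem). Broadcasting a (B, C) index array against arange(C) gathers a full (B, C) array, and only then is it summed along the columns. This shows the results agree; whether XLA fuses the gather and the reduction is not something this snippet can show.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)
R, C, B = 4, 6, 5  # small stand-ins for 400, 10000 and batch_size
data = jnp.asarray(rng.random((R, C)), dtype=jnp.float32)
batch_to_sum = jnp.asarray(rng.integers(0, R, size=(B, C)))
cols = jnp.arange(C)

def _func(indxs_to_sum):
    return data[indxs_to_sum, cols].sum()

vmapped = jax.jit(jax.vmap(_func))(batch_to_sum)

# Equivalent batched expression: the gather produces a (B, C) intermediate ...
intermediate = data[batch_to_sum, cols]  # shape (B, C)
# ... which is then reduced along the column axis.
direct = intermediate.sum(axis=1)        # shape (B,)
```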
The way I understand it, all I need is a for loop (which I can code up with a fori_loop) where, in each iteration, I take a 1D slice from data and add the values indexed by indxs_to_sum to a carry array. The advantage is that each iteration can run independently and in parallel, so no intermediate array is created at all.
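A minimal sketch of that loop idea, assuming the 1D slice is a column of data: iterate over the C columns with jax.lax.fori_loop, gather data[batch_to_sum[:, c], c] (a length-B vector) and add it to a (B,) carry. The name loop_sum is hypothetical. Note that fori_loop iterations execute one after another, so this trades the (B x C) intermediate for C sequential steps; it is not guaranteed to be faster.

```python
import numpy as np
import jax
import jax.numpy as jnp

def loop_sum(data, batch_to_sum):
    """Hypothetical helper: out[b] = sum_c data[batch_to_sum[b, c], c]."""
    num_cols = data.shape[1]
    init = jnp.zeros(batch_to_sum.shape[0], dtype=data.dtype)  # (B,) carry

    def body(c, carry):
        column = data[:, c]       # 1D slice of data, shape (R,)
        idx = batch_to_sum[:, c]  # one row index per batch entry, shape (B,)
        return carry + column[idx]  # only a length-B gather per iteration

    return jax.lax.fori_loop(0, num_cols, body, init)
```

Usage would be jax.jit(loop_sum)(data, batch_to_sum), which should match func(batch_to_sum) up to floating-point summation order.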
In brief, my question is whether there is any obvious refactoring of func in the example below that would avoid creating a huge intermediate array.
import time
import numpy as np
import jax
import jax.numpy as jnp
batch_size = 10_000  # This number can grow as high as 10_000_000 depending on the available GPU
# Create mock data
data = jnp.array(np.random.random((400, 10000)))  # 2D array from which data should be selected
batch_to_sum = jnp.array(  # Batch of 1D index arrays: each row holds one row index per column of data
np.random.randint(data.shape[0], size=(batch_size, data.shape[1]))
)
def _func(indxs_to_sum):  # This is the function: sums one element from each column of data
return data[indxs_to_sum, jnp.arange(data.shape[1])].sum()
func = jax.jit(jax.vmap(_func, in_axes=0, out_axes=0))
print(" *** vmap ***")
t0 = time.time()
res = func(batch_to_sum).block_until_ready()
print(f"vmap + jit 1st pass: {(time.time() - t0)}")
t0 = time.time()
res = func(batch_to_sum).block_until_ready()
print(f"vmap + jit 2nd pass: {(time.time() - t0)}")
> Output:
> *** vmap ***
> vmap + jit 1st pass: 0.13960623741149902
> vmap + jit 2nd pass: 0.004172086715698242
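One possible middle ground, sketched here as a swapped-in technique rather than the loop described above, is to split the batch into chunks and process them with jax.lax.map, so at most a (chunk_size x C) array is gathered at a time. The function name chunked_sum and the chunk_size parameter are hypothetical, and the sketch assumes chunk_size divides the batch size.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=2)
def chunked_sum(data, batch_to_sum, chunk_size):
    """Hypothetical helper: same result as vmap(_func), gathered chunk by chunk."""
    B, C = batch_to_sum.shape  # static shapes inside jit
    if B % chunk_size != 0:
        raise ValueError("this sketch assumes chunk_size divides the batch size")
    cols = jnp.arange(C)
    chunks = batch_to_sum.reshape(B // chunk_size, chunk_size, C)

    def one_chunk(idx):  # idx has shape (chunk_size, C)
        return data[idx, cols].sum(axis=1)  # shape (chunk_size,)

    # lax.map applies one_chunk sequentially over the leading axis of chunks.
    return jax.lax.map(one_chunk, chunks).reshape(B)
```

With the mock data above, chunked_sum(data, batch_to_sum, 1_000) would bound the intermediate at 1_000 x 10_000 elements; the chunk size is a tuning knob between memory use and the number of sequential steps.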