user25609470
jax sum creates a huge intermediate array slowing down GPU performance

I'm trying to create a jax function that selects a set of values from a 2D array and adds them together in a vectorized manner.

More specifically, given an (R x C) array data and a length-C array of ints indxs_to_sum, I can compute data[indxs_to_sum, np.arange(C)].sum(), which essentially adds one element from each column of data.
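To make the selection pattern concrete, here is a minimal sketch with small, made-up shapes (a 4 x 5 array and hand-picked indices, not the question's data):

```python
import numpy as np
import jax.numpy as jnp

R, C = 4, 5
data = jnp.array(np.arange(R * C, dtype=float).reshape(R, C))
indxs_to_sum = jnp.array([0, 2, 1, 3, 0])  # one row index per column

# Picks data[0, 0], data[2, 1], data[1, 2], data[3, 3], data[0, 4]
# and sums them: 0 + 11 + 7 + 18 + 4 = 40
result = data[indxs_to_sum, jnp.arange(C)].sum()
```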

The problem comes when I want to do this for several arrays like indxs_to_sum, i.e. I want indxs_to_sum to be a (B x C) array of ints (i.e. a batch of arrays of indices). My approach (see below) is to vmap the previous function, which compiles into a big Gather XLA operation. This, however, doesn't scale as well as expected: there is a bottleneck on the GPU when the (B x C) intermediate array is materialized before everything is reduced with a sum.
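That the batched indexing lowers to a gather can be checked by inspecting the jaxpr; the following is a small sketch with toy shapes (8 x 6 data, a batch of 3), assuming recent JAX where fancy indexing lowers to the gather primitive:

```python
import numpy as np
import jax
import jax.numpy as jnp

data = jnp.array(np.random.random((8, 6)))

def _func(indxs_to_sum):
    return data[indxs_to_sum, jnp.arange(data.shape[1])].sum()

batched = jax.vmap(_func)
batch = jnp.zeros((3, 6), dtype=jnp.int32)

# The printed jaxpr should contain a gather op for the batched indexing
jaxpr = jax.make_jaxpr(batched)(batch)
```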

The way I understand it, all I need is a for loop (which I can code up with a fori_loop) where, in each iteration, I take a 1D slice from data and add the values indexed by the corresponding column of indxs_to_sum to a carry array, with the advantage that each iteration only needs a small slice, so no (B x C) intermediate array is created at all.
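The loop idea described above might be sketched roughly as follows (func_loop is a hypothetical name; the shapes are shrunk for illustration). Each iteration indexes one (R,)-column of data with one (B,)-column of indices and accumulates into a (B,) carry, so only small slices are ever materialized. Note, though, that lax.fori_loop executes its iterations sequentially, so whether this is actually faster on a GPU would need measuring:

```python
import numpy as np
import jax
import jax.numpy as jnp

R, C, B = 16, 32, 8  # toy sizes for illustration
data = jnp.array(np.random.random((R, C)))
batch_to_sum = jnp.array(np.random.randint(R, size=(B, C)))

@jax.jit
def func_loop(batch):
    # One iteration per column: gather B values from the c-th column of
    # data and add them to the (B,) carry -- no (B, C) intermediate.
    def body(c, carry):
        return carry + data[:, c][batch[:, c]]
    return jax.lax.fori_loop(0, data.shape[1], body, jnp.zeros(batch.shape[0]))
```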

In brief, my question is whether there's any obvious refactoring that I could do to func in the example below so as to avoid creating a huge intermediate array.

import time
import numpy as np
import jax
import jax.numpy as jnp

batch_size = 10_000  # This number can grow as high as 10_000_000 depending on the available GPU

# Create mock data
data = jnp.array(np.random.random((400, 10000)))  # 2D array from which data should be selected
batch_to_sum = jnp.array(  # Batch of 1D index arrays, one row index per column of data
    np.random.randint(data.shape[0], size=(batch_size, data.shape[1]))
)

def _func(indxs_to_sum):  # Sums one selected element per column of data
    return data[indxs_to_sum, jnp.arange(data.shape[1])].sum()
func = jax.jit(jax.vmap(_func, in_axes=0, out_axes=0))

print(" *** vmap ***")
t0 = time.time()
res = func(batch_to_sum).block_until_ready()
print(f"vmap + jit 1st pass: {(time.time() - t0)}")
t0 = time.time()
res = func(batch_to_sum).block_until_ready()
print(f"vmap + jit 2nd pass: {(time.time() - t0)}")

> Output:
> *** vmap ***
> vmap + jit 1st pass: 0.13960623741149902
> vmap + jit 2nd pass: 0.004172086715698242

Upvotes: 1

Views: 144

Answers (0)