Federico Taschin

Reputation: 2195

JAX batching with different lengths

I have a function compute(x) where x is a jnp.ndarray. Now, I want to use vmap to transform it into a function that takes a batch of arrays x[i], and then jit to speed it up. compute(x) is something like:

def compute(x):
    # ... some code
    y = very_expensive_function(x)
    return y

However, each array x[i] has a different length. I can easily work around this problem by padding the arrays with trailing zeros so that they all have the same length N, and vmap(compute) can then be applied to batches of shape (batch_size, N).

Doing so, however, causes very_expensive_function() to also be called on the trailing zeros of each array x[i]. Is there a way to modify compute() such that very_expensive_function() is called only on a slice of x, without interfering with vmap and jit?
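To make it concrete, my current padding workaround looks roughly like this (a minimal sketch with made-up example inputs):

import jax
import jax.numpy as jnp

# Pad every array with trailing zeros to the common length N,
# then stack into a (batch_size, N) batch that vmap/jit can handle.
arrays = [jnp.arange(3.0), jnp.arange(5.0), jnp.arange(4.0)]  # example inputs
N = max(a.shape[0] for a in arrays)
batch = jnp.stack([jnp.pad(a, (0, N - a.shape[0])) for a in arrays])

batched_compute = jax.jit(jax.vmap(compute))
y = batched_compute(batch)  # very_expensive_function also runs on the zeros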

Upvotes: 5

Views: 2961

Answers (1)

Robin

Reputation: 1599

With JAX, when you want to jit a function to speed things up, the batch parameter x you pass in must be a well-defined ndarray (i.e. all the x[i] must have the same shape). This is true whether or not you are using vmap.

Now, the usual way of dealing with that is to pad these arrays. This means adding a mask to your parameters so that the padded values don't affect your result. For example, if I want to compute the softmax of a padded batch x of shape (batch_size, max_length), I need to "disable" the effect of the padded values. Here is an example:

import jax.numpy as jnp
import jax

PAD = 0
MINUS_INFINITY = -1e6

# Padded batch: every row has the same length (here 4).
x = jnp.array([
    [1, 2, 3, 4],
    [1, 2, PAD, PAD],
    [1, 2, 3, PAD]
])

# Mask: 1 for real values, 0 for padded values.
mask = jnp.array([
    [1, 1, 1, 1],
    [1, 1, 0, 0],
    [1, 1, 1, 0]
])

# Push the padded positions towards minus infinity so they get ~0 probability.
masked_softmax = jax.nn.softmax(x + (1 - mask) * MINUS_INFINITY)

So it is not quite as simple as just padding x: you need to change your computation at each step so that the padding cannot affect the result. In the case of softmax, you do this by pushing the padded values towards minus infinity.
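The same trick applies to other reductions. For example (an illustrative sketch reusing x and mask from above), a mean taken only over the real, unpadded values:

# Zero out padded entries in the sum, and divide by the count of real values.
masked_mean = (x * mask).sum(axis=-1) / mask.sum(axis=-1)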

Finally, you can't really know in advance whether performance will be better with padding + masking than without. In my experience, it often gives a good improvement on CPU and a very big improvement on GPU. In particular, the choice of batch size has a big effect on performance, since a higher batch_size will statistically lead to a higher max_length, and hence to more "useless" computation performed on the padded values.
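To tie this back to the question, here is a rough sketch of how a masked compute could then be batched with vmap and compiled with jit (the masked softmax just stands in for very_expensive_function; your real computation will differ):

import jax
import jax.numpy as jnp

def compute(x, mask):
    # Stand-in for the expensive step: a masked softmax as above.
    return jax.nn.softmax(x + (1 - mask) * -1e6)

batched_compute = jax.jit(jax.vmap(compute))

x_batch = jnp.array([[1., 2., 3., 4.],
                     [1., 2., 0., 0.]])
mask_batch = jnp.array([[1., 1., 1., 1.],
                        [1., 1., 0., 0.]])
y = batched_compute(x_batch, mask_batch)  # shape (2, 4)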

Upvotes: 5
