Reputation: 2195
I have a function compute(x) where x is a jnp.ndarray. Now, I want to use vmap to transform it into a function that takes a batch of arrays x[i], and then jit to speed it up. compute(x) is something like:
def compute(x):
    # ... some code
    y = very_expensive_function(x)
    return y
However, each array x[i] has a different length. I can easily work around this problem by padding the arrays with trailing zeros such that they all have the same length N, and vmap(compute) can be applied on batches with shape (batch_size, N).
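For example, a minimal sketch of this padding (with a toy stand-in for very_expensive_function, since the real one doesn't matter here) would be:
import numpy as np
import jax
import jax.numpy as jnp

def very_expensive_function(x):
    # Toy stand-in for the real elementwise computation.
    return jnp.exp(jnp.sin(x))

def compute(x):
    # ... some code
    return very_expensive_function(x)

# Ragged inputs of different lengths.
arrays = [np.arange(3.0), np.arange(5.0), np.arange(2.0)]
N = max(len(a) for a in arrays)

# Pad each array with trailing zeros to length N and stack into one batch.
batch = jnp.stack([jnp.pad(a, (0, N - len(a))) for a in arrays])  # shape (batch_size, N)

y = jax.jit(jax.vmap(compute))(batch)  # very_expensive_function also sees the padding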
Doing so, however, causes very_expensive_function() to be called on the trailing zeros of each array x[i] as well. Is there a way to modify compute() so that very_expensive_function() is called only on a slice of x, without interfering with vmap and jit?
Upvotes: 5
Views: 2961
Reputation: 1599
With JAX, when you want to jit a function to speed things up, the given batch parameter x must be a well-defined ndarray (i.e. all the x[i] must have the same shape). This is true whether or not you are using vmap.
Now, the usual way of dealing with that is to pad these arrays. This implies that you add a mask to your parameters such that the padded values don't affect your result. For example, if I want to compute the softmax of padded values x of shape (batch_size, max_length), I need to "disable" the effect of the padded values. Here is an example:
import jax
import jax.numpy as jnp

PAD = 0
MINUS_INFINITY = -1e6

x = jnp.array([
    [1, 2, 3, 4],
    [1, 2, PAD, PAD],
    [1, 2, 3, PAD],
])
mask = jnp.array([
    [1, 1, 1, 1],
    [1, 1, 0, 0],
    [1, 1, 1, 0],
])

# Padded positions get a large negative logit, so softmax assigns them ~0 weight.
masked_softmax = jax.nn.softmax(x + (1 - mask) * MINUS_INFINITY)
It is not as trivial as just padding x: you need to actually change your computation at each step to disable the effect of the padding. In the case of softmax, you do this by setting the padded values close to minus infinity.
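Applied to your compute(), the same idea would be to pass the mask (or the original lengths) alongside x and use it to cancel the padded positions after very_expensive_function has run. Note that the expensive function is still evaluated on the padding; you only remove its effect from the result. Here is a minimal sketch, with a toy placeholder for very_expensive_function and assuming the final result is a reduction over x:
import jax
import jax.numpy as jnp

def very_expensive_function(x):
    # Toy placeholder for the real computation.
    return jnp.exp(jnp.sin(x)) ** 3

def compute(x, mask):
    y = very_expensive_function(x)
    # Keep only the contributions of the non-padded positions.
    return jnp.sum(jnp.where(mask, y, 0.0))

batched_compute = jax.jit(jax.vmap(compute))

x = jnp.array([
    [1.0, 2.0, 3.0, 4.0],
    [1.0, 2.0, 0.0, 0.0],
    [1.0, 2.0, 3.0, 0.0],
])
lengths = jnp.array([4, 2, 3])
mask = jnp.arange(x.shape[1]) < lengths[:, None]  # True on data, False on padding

print(batched_compute(x, mask))  # one value per (padded) row
If your result is not a simple reduction, the principle stays the same: every operation that mixes positions (sum, mean, softmax, etc.) needs to see the mask so the padded entries don't leak into the output.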
Finally, you can't really know in advance whether performance will be better with or without padding + masking. In my experience, it often leads to a good improvement on CPU, and to a very big improvement on GPU. In particular, the choice of batch size has a big effect on performance, since a larger batch_size will statistically lead to a larger max_length, and hence to more "useless" computation performed on the padded values.
Upvotes: 5