Reputation: 161
I am playing around with JAX and I want to make sure I understand the "right" way to do batching.
It seems possible to write my "model" code as working over a single "instance" of data and then rely on vmap to "batch." Is this the correct way? Other tools I have worked with in the past (pytorch, tf) typically have a "batch" dimension that is kind of implicit. I assumed that this is how the actual GPU operations were implemented, and that there had to be some sort of inherent efficiency to this batching.
My 2 questions are:
This is a theoretical question. My code currently works, but I am just curious as to the "why" of my approach.
Upvotes: 2
Views: 1464
Reputation: 86443
If I understand your question correctly, I think you'll find that vmap produces identical results (with identical performance) to "native" batching.
Here's a quick demonstration. Suppose you've defined a simple model for a single input:
import jax
import jax.numpy as jnp
import numpy as np
rng = np.random.default_rng(98432)
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0
def model(v, M=M, b=b):
    return jnp.tanh(M @ v + b).sum()
v = jnp.array(rng.normal(size=3))
print(model(v))
# 1.7771413
What happens when you try to run this on batched input? Well, you get an error because your model definition didn't anticipate batches:
# 5x3 = a batch of 5 length-3 inputs
v_batched = jnp.array(rng.normal(size=(5, 3)))
print(model(v_batched))
#---------------------------------------------------------------------------
# TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (5,).
So what should you do? One option is to redefine your model so that it accepts batches. This takes some thought: in particular, we replace the simple matrix product with an einsum representing its batched version:
def model_batched(v_batched, M=M, b=b):
    # Note: v_batched.shape = (n_batches, m)
    #       M.shape = (k, m)
    #       output.shape = (n_batches, k)
    # So replace dot with appropriate einsum
    return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)
print(jnp.array([model(v) for v in v_batched])) # slow loops for validation!
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
print(model_batched(v_batched)) # fast manually-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
But it's not great to have to rewrite the model every time we want to batch an operation. This is where vmap comes in: it automatically transforms the model into a batched version (without having to rewrite the code!), and it produces the same result given the original model definition:
print(jax.vmap(model)(v_batched)) # fast automatically-vectorized version
# [-0.14736587 0.47015858 1.8918197 0.21948916 1.0849661 ]
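If the parameters are passed explicitly instead of being captured as default arguments, `in_axes` tells vmap which arguments carry a batch dimension. The following is a minimal sketch of that variant, assuming the same model as above; the seed and the name `batched_model` are illustrative choices, not part of the original example.

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, for illustration only
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0

def model(v, M, b):
    # Written for a single length-3 input; no batch dimension appears here.
    return jnp.tanh(M @ v + b).sum()

v_batched = jnp.array(rng.normal(size=(5, 3)))

# Map over axis 0 of v only; M and b are shared across the batch (None).
batched_model = jax.vmap(model, in_axes=(0, None, None))
out = batched_model(v_batched, M, b)
print(out.shape)  # (5,)
```

Passing `None` for an argument means it is broadcast to every element of the batch rather than sliced.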
You might ask now which one of these approaches is more efficient: it turns out that under the hood, both the manual and automatic vectorized approaches lower to an identical sequence of operations, which you can confirm by looking at the jaxpr for each.
Here's the manually batched version:
print(jax.make_jaxpr(model_batched)(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = xla_call[
      call_jaxpr={ lambda ; d:f32[2,3] e:f32[5,3]. let
          f:f32[2,5] = dot_general[
            dimension_numbers=(((1,), (1,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] d e
        in (f,) }
      name=_einsum
    ] a b
    g:f32[2,5] = add c 1.0
    h:f32[2,5] = tanh g
    i:f32[5] = reduce_sum[axes=(0,)] h
  in (i,) }
And here's the automatically-batched version:
print(jax.make_jaxpr(jax.vmap(model))(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] a b
    d:f32[2,5] = add c 1.0
    e:f32[2,5] = tanh d
    f:f32[5] = reduce_sum[axes=(0,)] e
  in (f,) }
The only difference is the xla_call wrapping the einsum, which is essentially a way of naming an operation or set of operations. The actual sequence of operations is identical between the two approaches: dot_general, then add, then tanh, then reduce_sum.
So the advantage of vmap is not that it produces better or faster code, but that it allows you to efficiently run your code across batches of data without having to rewrite the model to specifically handle batched inputs.
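If you want to check the performance claim on your own hardware, one rough approach is to jit both versions and time them after a warm-up call. This is only a sketch: the batch size of 1000, the seed, and the repetition count are arbitrary assumptions, and the timings will vary by machine and JAX version.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, for illustration only
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0

def model(v):
    # Single-input model.
    return jnp.tanh(M @ v + b).sum()

def model_batched(v_batched):
    # Hand-written batched model.
    return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)

v_batched = jnp.array(rng.normal(size=(1000, 3)))  # hypothetical batch size

auto = jax.jit(jax.vmap(model))
manual = jax.jit(model_batched)

# Call each function once so compilation is excluded from the timing.
auto_out = auto(v_batched).block_until_ready()
manual_out = manual(v_batched).block_until_ready()

# block_until_ready() waits for the asynchronous computation to finish.
t_auto = timeit.timeit(lambda: auto(v_batched).block_until_ready(), number=100)
t_manual = timeit.timeit(lambda: manual(v_batched).block_until_ready(), number=100)
print(f"vmap: {t_auto:.4f}s  manual: {t_manual:.4f}s")
```

For a model this small, any difference between the two timings is likely to be within measurement noise.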
Upvotes: 3
Reputation: 7877
vmap rewrites your program to use the same batching approach that NumPy, PyTorch, or TensorFlow would. So yes, aside from the initial call to rewrite your program, it is as efficient.
How does that work? JAX uses the XLA compiler to execute programs. XLA works like you're used to seeing, with explicit batch dimensions in most of its API. JAX hides those batch dimensions so you don't have to think about them, but provides vmap, which traverses and rewrites your program to use those batch dimensions when you need them. The same old batching you're familiar with was always available; JAX just doesn't expose it until it's needed.
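As a small illustration of that idea, the sketch below writes a dot product for two plain vectors and then compares the vmapped version against a hand-written einsum with an explicit batch index. The function name `dot` and the input values are made up for this example.

```python
import jax
import jax.numpy as jnp

def dot(x, y):
    # Written for two 1-D vectors; no batch dimension anywhere.
    return jnp.dot(x, y)

xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 length-3 vectors
ys = jnp.ones((4, 3))

batched = jax.vmap(dot)(xs, ys)            # batch axis introduced by vmap
explicit = jnp.einsum('ni,ni->n', xs, ys)  # batch axis written by hand

print(batched.shape, explicit.shape)
print(jnp.allclose(batched, explicit))
```

Both calls return one scalar per batch element, so the explicit batch index `n` in the einsum plays the role that vmap fills in automatically.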
Upvotes: 1