ibebrett

Reputation: 161

Is vmap efficient as compared to batched ops?

I am playing around with some JAX, and I want to make sure I understand the "right" way to do batching.

It seems possible to write my "model" code as working over a single "instance" of data and then rely on vmap to "batch" it. Is this the correct way? Other tools I have worked with in the past (PyTorch, TF) typically have a "batch" dimension that is more or less implicit. I assumed that this is how the actual GPU operations were implemented, and that there had to be some inherent efficiency to this batching.

My 2 questions are:

  1. Is vmap the correct/expected way to batch-train models in JAX (at least most of the time)?
  2. Wouldn't per-operation batching be faster, and be handled more naturally by some CUDA function somewhere (when using CUDA)? Does JAX realize that it is not vmapping over my model parameter dimensions, and use the correct batched matmuls and other ops? Or do the ops not actually work like this, so that vmapping (naively batching over the entire sequence of calculations) is what's actually happening even in something like PyTorch?

This is a theoretical question. My code currently works, but I am curious about the "why" of my approach.

Upvotes: 2

Views: 1464

Answers (2)

jakevdp

Reputation: 86443

If I understand your question correctly, I think you'll find that vmap produces identical results (with identical performance) to "native" batching.

Here's a quick demonstration. Suppose you've defined a simple model for a single input:

import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(98432)

M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0

def model(v, M=M, b=b):
  return jnp.tanh(M @ v + b).sum()


v = jnp.array(rng.normal(size=3))

print(model(v))
# 1.7771413

What happens when you try to run this on batched input? Well, you get an error because your model definition didn't anticipate batches:

# 5x3 = 5 batches of length-3 inputs
v_batched = jnp.array(rng.normal(size=(5, 3)))
print(model(v_batched))
#---------------------------------------------------------------------------
# TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (5,).

So what should you do? One option is to redefine your model so that it accepts batches. This takes some thought; in particular, we replace the simple matrix product with an einsum representing its batched version:

def model_batched(v_batched, M=M, b=b):
  # Note: v_batched.shape = (n_batches, m)
  #       M.shape = (k, m)
  #       output.shape = (n_batches, k)
  # So replace dot with appropriate einsum
  return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)

print(jnp.array([model(v) for v in v_batched]))  # slow loops for validation!
# [-0.14736587  0.47015858  1.8918197   0.21948916  1.0849661 ]

print(model_batched(v_batched))  # fast manually-vectorized version
# [-0.14736587  0.47015858  1.8918197   0.21948916  1.0849661 ]
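For a single linear layer, the einsum can also be written as a plain matrix product against the transposed weights. A minimal sketch, assuming the same shapes as above; `model_batched_matmul` is a hypothetical name used only for this comparison, and the random draws here differ from the ones above, so the values will not match the outputs shown.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(98432)
M = jnp.array(rng.normal(size=(2, 3)))       # weights, shape (k, m)
b = 1.0
v_batched = jnp.array(rng.normal(size=(5, 3)))  # inputs, shape (n, m)

def model_batched_einsum(v_batched, M=M, b=b):
    # Same as model_batched above: (k, m) x (n, m) -> (n, k)
    return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)

def model_batched_matmul(v_batched, M=M, b=b):
    # Hypothetical alternative: (n, m) @ (m, k) -> (n, k)
    return jnp.tanh(v_batched @ M.T + b).sum(axis=1)

print(jnp.allclose(model_batched_einsum(v_batched),
                   model_batched_matmul(v_batched), atol=1e-5))
```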

But it's not great to have to rewrite the model every time we want to batch an operation. This is where vmap comes in: it automatically transforms the model into a batched version (without having to rewrite the code!), and it produces the same result given the original model definition:

print(jax.vmap(model)(v_batched))  # fast automatically-vectorized version
# [-0.14736587  0.47015858  1.8918197   0.21948916  1.0849661 ]
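On the question of parameter dimensions: vmap only maps over the arguments and axes you tell it to via `in_axes`, and `None` marks an argument as shared across the batch. A minimal sketch, assuming the model is rewritten to take `M` and `b` as explicit arguments instead of defaults; the ensemble at the end is a hypothetical illustration of mapping over parameters instead of data.

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)

def model(v, M, b):
    # Single-example model: v has shape (m,), M has shape (k, m)
    return jnp.tanh(M @ v + b).sum()

M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0
v_batched = jnp.array(rng.normal(size=(5, 3)))

# Map over axis 0 of v only; M and b are shared by every example (None = not mapped)
batched_model = jax.vmap(model, in_axes=(0, None, None))
out = batched_model(v_batched, M, b)
print(out.shape)  # (5,)

# Hypothetical ensemble: one input, 4 different weight matrices
M_ensemble = jnp.array(rng.normal(size=(4, 2, 3)))
ensemble_out = jax.vmap(model, in_axes=(None, 0, None))(v_batched[0], M_ensemble, b)
print(ensemble_out.shape)  # (4,)
```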

You might now ask which of these approaches is more efficient. It turns out that under the hood, both the manual and the automatic vectorized approaches lower to an identical sequence of operations, which you can confirm by looking at the jaxpr for each.

Here's the manually batched version:

print(jax.make_jaxpr(model_batched)(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = xla_call[
      call_jaxpr={ lambda ; d:f32[2,3] e:f32[5,3]. let
          f:f32[2,5] = dot_general[
            dimension_numbers=(((1,), (1,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] d e
        in (f,) }
      name=_einsum
    ] a b
    g:f32[2,5] = add c 1.0
    h:f32[2,5] = tanh g
    i:f32[5] = reduce_sum[axes=(0,)] h
  in (i,) }

And here's the automatically-batched version:

print(jax.make_jaxpr(jax.vmap(model))(v_batched))
{ lambda a:f32[2,3]; b:f32[5,3]. let
    c:f32[2,5] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] a b
    d:f32[2,5] = add c 1.0
    e:f32[2,5] = tanh d
    f:f32[5] = reduce_sum[axes=(0,)] e
  in (f,) }

The only difference is the xla_call wrapping the einsum, which is essentially a way of naming an operation or set of operations (newer JAX releases may print this wrapper under a different name, such as pjit). The actual sequence of operations is identical between the two approaches: dot_general, then add, then tanh, then reduce_sum.
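If you would rather check this programmatically than by reading the printed jaxprs, a jaxpr can be inspected directly. A minimal sketch: `primitive_names` is a hypothetical helper that lists the top-level primitives of each jaxpr (a wrapper such as xla_call shows up as a single entry), and the final line checks that the vmapped program contains a `dot_general` and no loop primitive such as `while` or `scan`. The exact names printed may vary between JAX versions.

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(98432)
M = jnp.array(rng.normal(size=(2, 3)))
b = 1.0
v_batched = jnp.array(rng.normal(size=(5, 3)))

def model(v, M=M, b=b):
    return jnp.tanh(M @ v + b).sum()

def model_batched(v_batched, M=M, b=b):
    return jnp.tanh(jnp.einsum('km,nm->nk', M, v_batched) + b).sum(1)

def primitive_names(closed_jaxpr):
    # Hypothetical helper: top-level primitive names only, wrappers are not expanded
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

print(primitive_names(jax.make_jaxpr(model_batched)(v_batched)))
print(primitive_names(jax.make_jaxpr(jax.vmap(model))(v_batched)))

# The printed jaxpr (including any nested jaxprs) has a matmul but no loop
text = str(jax.make_jaxpr(jax.vmap(model))(v_batched))
print("dot_general" in text, "while" in text, "scan" in text)
```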

So the advantage of vmap is not that it produces better or faster code, but that it allows you to efficiently run your code across batches of data without having to rewrite the model to specifically handle batched inputs.
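In a training loop, a common pattern is to write the loss for one example, vmap it over the batch, reduce with a mean, and then take the gradient and jit the update step. A minimal sketch, assuming a hypothetical single-layer model, a squared-error loss, and plain SGD with a fixed learning rate; none of these names come from the question.

```python
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.05  # assumed fixed step size

def per_example_loss(params, x, y):
    # Hypothetical single-example model with squared error
    pred = jnp.tanh(params["M"] @ x + params["b"])
    return jnp.sum((pred - y) ** 2)

def batch_loss(params, xs, ys):
    # params shared across the batch (None); map over axis 0 of xs and ys
    losses = jax.vmap(per_example_loss, in_axes=(None, 0, 0))(params, xs, ys)
    return jnp.mean(losses)

@jax.jit
def sgd_step(params, xs, ys):
    grads = jax.grad(batch_loss)(params, xs, ys)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

rng = np.random.default_rng(0)
params = {"M": jnp.array(rng.normal(size=(2, 3))), "b": jnp.zeros(2)}
xs = jnp.array(rng.normal(size=(8, 3)))
ys = jnp.array(rng.normal(size=(8, 2)))

loss_before = batch_loss(params, xs, ys)
for _ in range(50):
    params = sgd_step(params, xs, ys)
loss_after = batch_loss(params, xs, ys)
print(loss_before, loss_after)
```

The model code never mentions a batch dimension; vmap adds it, and grad and jit compose on top of the vmapped loss.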

Upvotes: 3

joel

Reputation: 7877

vmap rewrites your program to use the same batching approach that NumPy, PyTorch or TensorFlow would. So yes, aside from the initial call to rewrite your program, it is as efficient.

How does that work? JAX uses the XLA compiler to execute programs. XLA works the way you're used to, with explicit batch dimensions in most of its API. JAX hides those batch dimensions so you don't have to think about them, but provides vmap, which traverses your program and rewrites it to use those batch dimensions when you need them. The familiar batching was always available; JAX just doesn't expose it until it's needed.
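To make the point about explicit batch dimensions concrete: `jax.lax.dot_general` takes batch dimensions directly in its `dimension_numbers`, and vmap over a per-example matrix-vector product gives the same result. A minimal sketch with random inputs; whether vmap emits exactly this `dot_general` is an implementation detail, which you can inspect with `jax.make_jaxpr`.

```python
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(1)
As = jnp.array(rng.normal(size=(5, 2, 3)))  # 5 matrices of shape (2, 3)
vs = jnp.array(rng.normal(size=(5, 3)))     # 5 vectors of shape (3,)

# Explicit batch dimension: contract lhs axis 2 with rhs axis 1,
# batch over axis 0 of both operands
explicit = jax.lax.dot_general(
    As, vs, dimension_numbers=(((2,), (1,)), ((0,), (0,)))
)

# Per-example code, batched automatically by vmap
per_example = lambda A, v: A @ v
implicit = jax.vmap(per_example)(As, vs)

print(explicit.shape, implicit.shape)  # (5, 2) (5, 2)
print(jnp.allclose(explicit, implicit, atol=1e-5))

# Inspect what vmap produced
print(jax.make_jaxpr(jax.vmap(per_example))(As, vs))
```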

Upvotes: 1
