Davi Alefe

Reputation: 125

Output shape of jax's vmap

I'm trying to figure out exactly how the output of jax.vmap works. As an example, take the following code:

import numpy as np
from jax import vmap

func = lambda x: [[x,x],[x,x]]

vfunc = vmap(func, 0, out_axes=WHAT)
example_array = np.array([1,2,3])
vfunc(example_array)

I'd like vfunc's output array in the example to be something like:

[ [[1,1],[1,1]],
  [[2,2],[2,2]],
  [[3,3],[3,3]]
]

Is there a way to get that by setting out_axes to some specific value, or would I have to apply some shape transformations afterwards?

Upvotes: 1

Views: 737

Answers (2)

jakevdp

Reputation: 86453

It is not possible to use vmap alone to get the result you have in mind.

One thing to understand about vmap is that it operates on axes of JAX arrays; whenever those arrays are inside containers (such as lists, tuples, or dicts), the in_axes and out_axes arguments refer to the individual arrays within those containers. So, for example, if you return a tuple, it looks like this:

import jax.numpy as jnp
from jax import vmap

@vmap
def f(x):
  return (x, 2 * x)

x = jnp.arange(3)
print(f(x))
# (DeviceArray([0, 1, 2], dtype=int32), DeviceArray([0, 2, 4], dtype=int32))

Notice each element in the tuple is vmapped separately.
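Because each element of the tuple is handled separately, out_axes can itself be a tuple that mirrors the output structure, giving each output its own batch axis. A minimal sketch, assuming a hypothetical 2-D input `xs` so that moving the batch axis is visible in the shapes:

```python
import jax.numpy as jnp
from jax import vmap

def f(x):
    # x is one row of the batched input, shape (2,)
    return (x, 2 * x)

xs = jnp.arange(6).reshape(3, 2)  # hypothetical batch of 3 rows, each of shape (2,)

# out_axes mirrors the tuple: the first output is batched along axis 0,
# the second along axis 1.
a, b = vmap(f, in_axes=0, out_axes=(0, 1))(xs)
print(a.shape)  # (3, 2)
print(b.shape)  # (2, 3)
```

The container itself is never changed; out_axes only decides where the batch dimension goes inside each leaf array.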

Similarly, if you return a nested list of individual arrays, vmap applies to each individual array, and the results are returned in the same nested structure:

@vmap
def f(x):
  return [x, [x, x]]

print(f(jnp.array([1, 2])))
# [DeviceArray([1, 2], dtype=int32),
#  [DeviceArray([1, 2], dtype=int32), DeviceArray([1, 2], dtype=int32)]]

Notice that the nested list in the result has the same structure as the nested list returned by f.

So, as you can see, when you return something like [[x, x], [x, x]], the output structure will always be [[y, y], [y, y]], where each y is x with a batch dimension added. There is no way to use vmap to make the output structure differ from the structure your function returns, which is what your question asks for, so vmap by itself cannot do what you want.
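If you keep the original list-returning func, the shape transformation can be done after vmap instead. A minimal sketch, assuming the leaves all have the same shape so that jnp.array can stack them: the nested list of four (3,)-shaped arrays becomes one (2, 2, 3) array, and jnp.moveaxis then brings the batch axis to the front:

```python
import jax.numpy as jnp
from jax import vmap

func = lambda x: [[x, x], [x, x]]
example_array = jnp.array([1, 2, 3])

# vmap keeps the nested-list structure: each leaf is an array of shape (3,).
nested = vmap(func)(example_array)

# Stack the leaves into a single (2, 2, 3) array, then move the batch axis
# (currently last) to the front to get shape (3, 2, 2).
stacked = jnp.moveaxis(jnp.array(nested), -1, 0)
print(stacked.shape)  # (3, 2, 2)
```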

Now, if you change your question slightly and, rather than returning a nested list of arrays, return a single array constructed from such a nested list, then the vmap out_axes argument applies to the returned array as a whole (rather than to each individual array in a container). You can then use that array directly or, if you actually want the nested list structure, unpack the output like this:

@vmap
def func(x):
  return jnp.array([[x,x],[x,x]])

example_array = jnp.array([1,2,3])
result = func(example_array)
print(list(map(list, result)))
# [[DeviceArray([1, 1], dtype=int32), DeviceArray([1, 1], dtype=int32)],
#  [DeviceArray([2, 2], dtype=int32), DeviceArray([2, 2], dtype=int32)],
#  [DeviceArray([3, 3], dtype=int32), DeviceArray([3, 3], dtype=int32)]]
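Since func now returns one array of shape (2, 2) per example, out_axes simply picks where the batch dimension of size 3 is inserted in that array. A short sketch showing the three valid positions (the default, out_axes=0, is the layout asked for in the question):

```python
import jax.numpy as jnp
from jax import vmap

def func(x):
    return jnp.array([[x, x], [x, x]])

example_array = jnp.array([1, 2, 3])

# The per-example output has shape (2, 2); out_axes chooses where the
# batch axis of size 3 is placed in the combined result.
out0 = vmap(func, out_axes=0)(example_array)
out1 = vmap(func, out_axes=1)(example_array)
out2 = vmap(func, out_axes=2)(example_array)
print(out0.shape)  # (3, 2, 2)
print(out1.shape)  # (2, 3, 2)
print(out2.shape)  # (2, 2, 3)
```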

Upvotes: 3

I'mahdi

Reputation: 24059

You can change your code so that func returns a single JAX array instead of a nested list:

import jax.numpy as jnp
from jax import vmap

func = lambda x: jnp.array([[x,x],[x,x]])
# Or
# def func(x) :
#     return jnp.array([[x,x],[x,x]])


vfunc = vmap(func, in_axes=0, out_axes=0)
example_array = jnp.array([1,2,3])
vfunc(example_array)

Output:

DeviceArray([[[1, 1],
              [1, 1]],

             [[2, 2],
              [2, 2]],

             [[3, 3],
              [3, 3]]], dtype=int32)
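If plain nested Python lists are needed, exactly as written in the question, the array can be converted with its .tolist() method. A small sketch reusing the names from the code above:

```python
import jax.numpy as jnp
from jax import vmap

func = lambda x: jnp.array([[x, x], [x, x]])
vfunc = vmap(func, in_axes=0, out_axes=0)
result = vfunc(jnp.array([1, 2, 3]))

# .tolist() turns the (3, 2, 2) array into nested Python lists of ints.
nested = result.tolist()
print(nested)  # [[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]]]
```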

Upvotes: 0
