Schach21

vmap gives an inconsistent-size error when trying to calculate per-sample gradients

I am trying to implement a two-layer neural network and compute the gradient with respect to the second layer's weights (W2) for each sample.

My code looks like this:

import jax.numpy as jnp
from jax import random, nn, grad, vmap

key = random.PRNGKey(0)

x = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=jnp.float32)
y = jnp.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=jnp.float32)

W1 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=jnp.float32)
W2 = random.uniform(key, shape=(10,), minval=1, maxval=2, dtype=jnp.float32)
b = jnp.linspace(0, -9, 10, dtype=jnp.float32)
b = jnp.reshape(b, (1,10))

def predict(W1, W2, b, x):
  f1 = jnp.einsum('i,j->ji', W1, x)+b
  f1 = nn.relu(f1)
  f2 = jnp.einsum('i,ji->j', W2, f1)
  return f2

def loss(W1, W2, b, x, y):
  preds = predict(W1, W2, b, x)
  return jnp.mean(jnp.square(y-preds))

perex_grads = vmap(grad(loss, argnums=1), in_axes= (0, None, 0, 0, 0))
pers_grads = perex_grads(W1, W2, b, x, y)

Calling loss and grad(loss) directly works fine; the error only appears when I wrap grad in vmap.

The exact error I get is:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
arg 0 has shape (10,) and axis 0 is to be mapped
arg 1 has shape (10,) and axis None is to be mapped
arg 2 has shape (1, 10) and axis 0 is to be mapped
arg 3 has shape (11,) and axis 0 is to be mapped
arg 4 has shape (11,) and axis 0 is to be mapped
so
arg 0 has an axis to be mapped of size 10
arg 2 has an axis to be mapped of size 1
args 3, 4 have axes to be mapped of size 11

This is my first time using JAX. A Google search didn't help me resolve the issue, and I found the documentation unclear. I'd appreciate any help.

Answers (1)

jakevdp

The issue is exactly what the error message says: to vmap an operation over multiple arrays, the mapped axes of all those arrays must have the same length. In your arrays they do not: you passed in_axes=(0, None, 0, 0, 0) for the arguments W1, W2, b, x, y, but W1.shape[0] = 10, b.shape[0] = 1, x.shape[0] = 11, and y.shape[0] = 11.

Because these lengths differ, vmap has no consistent way to pair up the slices, so it raises this error. To avoid it, map only over array axes that have the same length.
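A minimal illustration of this rule, using toy arrays (the function f and the arrays a, c, d are hypothetical and not from the question): every argument mapped with axis 0 must have the same length along that axis, while an argument given None is passed whole to every call, so its length does not matter.

```python
import jax.numpy as jnp
from jax import vmap

def f(a, c):
  return a * c

a = jnp.arange(3.0)  # mapped axis 0 has length 3
c = jnp.arange(3.0)  # mapped axis 0 has length 3
d = jnp.arange(4.0)  # mapped axis 0 has length 4

# Both arguments mapped, lengths agree (3 == 3): elementwise product.
ok = vmap(f, in_axes=(0, 0))(a, c)
print(ok)  # [0. 1. 4.]

# d is not mapped (None), so each of the 3 calls sees all of d; output is (3, 4).
bcast = vmap(f, in_axes=(0, None))(a, d)

# Both arguments mapped, but lengths disagree (3 != 4): vmap raises ValueError.
try:
  vmap(f, in_axes=(0, 0))(a, d)
  raised = False
except ValueError:
  raised = True
```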

For example, if you want the gradients with respect to W2 computed per pair of W1, W2 inputs, it might look something like this (note the updated predict function and updated in_axes):

import jax.numpy as jnp
from jax import random, nn, grad, vmap

key = random.PRNGKey(0)

x = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=jnp.float32)
y = jnp.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=jnp.float32)

W1 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=jnp.float32)
W2 = random.uniform(key, shape=(10,), minval=1, maxval=2, dtype=jnp.float32)
b = jnp.linspace(0, -9, 10, dtype=jnp.float32)
b = jnp.reshape(b, (1,10))

def predict(W1, W2, b, x):
  # Since you're vmapping over W1 and W2, your function needs to
  # handle scalar values, so we cast to 1D if necessary.
  W1 = jnp.atleast_1d(W1)
  W2 = jnp.atleast_1d(W2)

  f1 = jnp.einsum('i,j->ji', W1, x)+b
  f1 = nn.relu(f1)
  f2 = jnp.einsum('i,ji->j', W2, f1)
  return f2

def loss(W1, W2, b, x, y):
  preds = predict(W1, W2, b, x)
  return jnp.mean(jnp.square(y-preds))

perex_grads = vmap(grad(loss, argnums=1), in_axes= (0, 0, None, None, None))
pers_grads = perex_grads(W1, W2, b, x, y)
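If what you want is one gradient per training sample, as the question title suggests, the axes to map are the sample axes of x and y (both of length 11), while W1, W2, and b are shared across samples (in_axes entry None). The following is a sketch under that assumption, not part of the answer above; because each mapped call sees a scalar x, this version of predict promotes it with jnp.atleast_1d so the einsum subscripts still apply.

```python
import jax.numpy as jnp
from jax import random, nn, grad, vmap

key = random.PRNGKey(0)

x = jnp.arange(11, dtype=jnp.float32)  # 11 samples
y = jnp.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=jnp.float32)

W1 = jnp.ones(10, dtype=jnp.float32)
W2 = random.uniform(key, shape=(10,), minval=1, maxval=2, dtype=jnp.float32)
b = jnp.linspace(0, -9, 10, dtype=jnp.float32).reshape(1, 10)

def predict(W1, W2, b, x):
  # Under vmap over x, x arrives as a scalar; make it shape (1,).
  x = jnp.atleast_1d(x)
  f1 = nn.relu(jnp.einsum('i,j->ji', W1, x) + b)  # (n_samples, 10)
  return jnp.einsum('i,ji->j', W2, f1)            # (n_samples,)

def loss(W1, W2, b, x, y):
  preds = predict(W1, W2, b, x)
  return jnp.mean(jnp.square(y - preds))

# Map only over the sample axis of x and y; share W1, W2, and b.
per_sample_grads = vmap(grad(loss, argnums=1), in_axes=(None, None, None, 0, 0))
g = per_sample_grads(W1, W2, b, x, y)
print(g.shape)  # (11, 10): one gradient w.r.t. W2 per sample
```

Since loss averages over samples, averaging these per-sample gradients over axis 0 should reproduce grad(loss, argnums=1) evaluated on the full batch, which makes a convenient sanity check.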
