Reputation: 450
I am trying to implement a two-layer neural network and compute the gradient with respect to the second layer's weights for each sample.
My code looks like this:
x = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=jnp.float32)
y = jnp.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=jnp.float32)
W1 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=jnp.float32)
W2 = random.uniform(key, shape=(10,), minval=1, maxval=2, dtype=jnp.float32)
b = jnp.linspace(0, -9, 10, dtype=jnp.float32)
b = jnp.reshape(b, (1,10))
def predict(W1, W2, b, x):
    f1 = jnp.einsum('i,j->ji', W1, x)+b
    f1 = nn.relu(f1)
    f2 = jnp.einsum('i,ji->j', W2, f1)
    return f2
def loss(W1, W2, b, x, y):
    preds = predict(W1, W2, b, x)
    return jnp.mean(jnp.square(y-preds))
perex_grads = vmap(grad(loss, argnums=1), in_axes= (0, None, 0, 0, 0))
pers_grads = perex_grads(W1, W2, b, x, y)
I ran loss and can do grad(loss) just fine. Running vmap is the actual problem.
The exact error I get is:
ValueError: vmap got inconsistent sizes for array axes to be mapped:
arg 0 has shape (10,) and axis 0 is to be mapped
arg 1 has shape (10,) and axis None is to be mapped
arg 2 has shape (1, 10) and axis 0 is to be mapped
arg 3 has shape (11,) and axis 0 is to be mapped
arg 4 has shape (11,) and axis 0 is to be mapped
so
arg 0 has an axis to be mapped of size 10
arg 2 has an axis to be mapped of size 1
args 3, 4 have axes to be mapped of size 11
This is my first time using JAX. My Google searches didn't help me resolve the issue, and the documentation was not very clear to me. I'd appreciate it if anyone could help.
Upvotes: 2
Views: 1512
Reputation: 86443
The issue is exactly what the error message says: to vmap an operation over multiple arrays, the mapped axes of those arrays must all have the same size. In your arrays they do not: you passed in_axes=(0, None, 0, 0, 0) for arguments W1, W2, b, x, y, but W1.shape[0] = 10, b.shape[0] = 1, x.shape[0] = 11, and y.shape[0] = 11.
Because these are not equal, you get this error. To prevent this error, you should only vmap over array axes of the same length.
For example, if you want the gradients with respect to W2 computed per pair of W1, W2 inputs, it might look something like this (note the updated predict function and updated in_axes):
import jax.numpy as jnp
from jax import random, nn, grad, vmap
key = random.PRNGKey(0)
x = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=jnp.float32)
y = jnp.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=jnp.float32)
W1 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=jnp.float32)
W2 = random.uniform(key, shape=(10,), minval=1, maxval=2, dtype=jnp.float32)
b = jnp.linspace(0, -9, 10, dtype=jnp.float32)
b = jnp.reshape(b, (1,10))
def predict(W1, W2, b, x):
    # Since you're vmapping over W1 and W2, your function needs to
    # handle scalar values, so we cast to 1D if necessary.
    W1 = jnp.atleast_1d(W1)
    W2 = jnp.atleast_1d(W2)
    f1 = jnp.einsum('i,j->ji', W1, x)+b
    f1 = nn.relu(f1)
    f2 = jnp.einsum('i,ji->j', W2, f1)
    return f2
def loss(W1, W2, b, x, y):
    preds = predict(W1, W2, b, x)
    return jnp.mean(jnp.square(y-preds))
perex_grads = vmap(grad(loss, argnums=1), in_axes= (0, 0, None, None, None))
pers_grads = perex_grads(W1, W2, b, x, y)
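If what you actually want is one gradient of W2 per sample, as the question's wording suggests, a possible variant maps over x and y only (both have length 11) and holds W1, W2 and b fixed. This is a minimal sketch under that assumption, not necessarily what you intended: the name per_sample_grads and the atleast_1d call on x are my additions, needed because each vmapped call sees x and y as scalars.

```python
import jax.numpy as jnp
from jax import random, nn, grad, vmap

key = random.PRNGKey(0)
x = jnp.arange(11, dtype=jnp.float32)
y = jnp.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=jnp.float32)
W1 = jnp.ones(10, dtype=jnp.float32)
W2 = random.uniform(key, shape=(10,), minval=1, maxval=2, dtype=jnp.float32)
b = jnp.linspace(0, -9, 10, dtype=jnp.float32).reshape(1, 10)

def predict(W1, W2, b, x):
    # Under vmap each call receives a scalar x, so promote it to shape (1,).
    x = jnp.atleast_1d(x)
    f1 = nn.relu(jnp.einsum('i,j->ji', W1, x) + b)
    return jnp.einsum('i,ji->j', W2, f1)

def loss(W1, W2, b, x, y):
    preds = predict(W1, W2, b, x)
    return jnp.mean(jnp.square(y - preds))

# Map only over the sample axis of x and y; W1, W2 and b are shared (None).
per_sample_grads = vmap(grad(loss, argnums=1),
                        in_axes=(None, None, None, 0, 0))(W1, W2, b, x, y)
print(per_sample_grads.shape)  # one row per sample, one column per W2 entry
```

Because the full loss is a mean over samples, averaging these rows over axis 0 should agree with grad(loss, argnums=1) evaluated on the whole batch, which makes for a quick sanity check.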
Upvotes: 2