Baschdl

Reputation: 435

vmap over a list in jax

Using JAX, I am trying to compute per-sample gradients, process them, and then reduce them to the usual form for a standard parameter update. My working code looks like this:

differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)

# some code

gradients_summed_over_samples = []
for layer in gradients:
    (dw, db) = layer
    (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
    gradients_summed_over_samples.append((dw, db))

where gradients is of the form list(tuple(DeviceArray(...), DeviceArray(...)), ...).

I then tried to rewrite the loop with vmap (I am not sure whether it gives a speedup in the end):

def sum_samples(layer):
    (dw, db) = layer
    (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))

vmap(sum_samples)(gradients)

but sum_samples is called only once and not for each element in the list.

Is the list the problem, or am I misunderstanding something else?

Upvotes: 2

Views: 4217

Answers (1)

jakevdp

Reputation: 86310

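jax.vmap maps over the axes of JAX arrays, not over the elements of a Python list or tuple. A list or tuple of arrays is treated as a single pytree argument, so the function is called once and each array inside it is mapped along its leading axis. In addition, vmapped functions cannot modify inputs in-place; the function should return a value, and this return value will be stacked with the other return values to construct the output.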
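To make the "called only once" behavior concrete, here is a minimal sketch. The `gradients` list, its shapes, and the `inspect` helper are hypothetical stand-ins for the per-sample gradients in the question, and `jnp` is used as the alias for `jax.numpy`:

```python
import jax.numpy as jnp
from jax import vmap

# Hypothetical per-sample gradients for two layers, batch size 4.
gradients = [
    (jnp.ones((4, 3, 2)), jnp.ones((4, 2))),  # layer 1: (dw, db)
    (jnp.ones((4, 2, 1)), jnp.ones((4, 1))),  # layer 2: (dw, db)
]

calls = []  # records the shapes seen each time the function is traced

def inspect(layers):
    # `layers` is still the whole list; only the batch axis has been stripped
    # from every array inside it.
    calls.append([(dw.shape, db.shape) for dw, db in layers])
    return layers

out = vmap(inspect)(gradients)

print(len(calls))  # 1: traced once, not once per list element
print(calls[0])    # [((3, 2), (2,)), ((2, 1), (1,))]
```

The function sees all layers at once, each with its sample axis removed, which is why a per-layer function such as the one in the question does not get applied to each list element.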

For example, you could modify the function you defined and use it like this:

import jax.numpy as np
from jax import random, vmap

def sum_samples(layer):
    (dw, db) = layer
    (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
    return np.array([dw, db])

key = random.PRNGKey(1701)
data = random.uniform(key, (10, 2, 20))

result = vmap(sum_samples)(data)
print(result.shape)
# (10, 2)

Side note: if you're using this approach, the vmapped function above can be more concisely expressed as:

def sum_samples(layer):
    return layer.sum(1)
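If the gradients have to stay in the list-of-tuples form from the question (for example because layers have different shapes and cannot be stacked into one array), one possible alternative to vmap is `jax.tree_util.tree_map`, which applies a function to every array in the structure. This is only a sketch: the `gradients` value and its shapes below are made up for illustration, and `jnp` is used as the alias for `jax.numpy`.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-sample gradients for two layers, batch size 4.
gradients = [
    (jnp.ones((4, 3, 2)), jnp.ones((4, 2))),  # layer 1: (dw, db)
    (jnp.ones((4, 2, 1)), jnp.ones((4, 1))),  # layer 2: (dw, db)
]

# Sum every array over its leading (sample) axis.
# The list-of-tuples structure is preserved in the result.
gradients_summed_over_samples = jax.tree_util.tree_map(
    lambda g: jnp.sum(g, axis=0), gradients
)

for dw, db in gradients_summed_over_samples:
    print(dw.shape, db.shape)
# (3, 2) (2,)
# (2, 1) (1,)
```

This replaces the explicit Python loop from the question and can be wrapped in jit along with the rest of the update step.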

Upvotes: 4