safetyduck

Reputation: 6874

Does Equinox (JAX) not broadcast over a batch dimension, and instead expect you to use vmap?

https://docs.kidger.site/equinox/api/nn/mlp/#equinox.nn.MLP

The only way I was able to get MLP working is like this:

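import jax
import equinox as eqx
import numpy as np

key = jax.random.PRNGKey(0)  # PRNG key used to initialise the MLP's weights
mlp = eqx.nn.MLP(in_size=12, out_size=4, width_size=6, depth=5, key=key)
# vmap maps the MLP over the leading (batch) axis: (5, 12) -> (5, 4)
out = jax.vmap(mlp)(np.random.randn(5, 12))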

Is this the intended usage? If so, it differs a bit from other frameworks, though it may be safer.

Upvotes: 1

Views: 320

Answers (1)

Patrick Kidger

Reputation: 101

Yup, this is intended!

Every layer in eqx.nn acts on a single batch element, and you can apply them to batches by calling jax.vmap, exactly as you're doing.

See also this FAQ entry: https://docs.kidger.site/equinox/faq/#how-do-i-input-higher-order-tensors-eg-with-batch-dimensions-into-my-model

I hope that helps!

Upvotes: 4
