I'm new to Equinox and JAX, but I wanted to use them to simulate a dynamical system.
However, when I pass my system model (an Equinox module) to `jax.lax.scan`, I get `TypeError: unhashable type: 'ArrayImpl'`. I understand that JAX expects the function argument to be a pure function, but I thought an Equinox Module would emulate that.
Here is a test script that reproduces the error:
```python
import equinox as eqx
import jax
import jax.numpy as jnp


class EqxModel(eqx.Module):
    A: jax.Array
    B: jax.Array
    C: jax.Array
    D: jax.Array

    def __call__(self, states, inputs):
        x = states.reshape(-1, 1)
        u = inputs.reshape(-1, 1)
        x_next = self.A @ x + self.B @ u
        y = self.C @ x + self.D @ u
        return x_next.reshape(-1), y.reshape(-1)


def simulate(model, inputs, x0):
    xk = x0
    outputs = []
    for uk in inputs:
        xk, yk = model(xk, uk)
        outputs.append(yk)
    outputs = jnp.stack(outputs)
    return xk, outputs


A = jnp.array([[0.7, 1.0], [0.0, 1.0]])
B = jnp.array([[0.0], [1.0]])
C = jnp.array([[0.3, 0.0]])
D = jnp.array([[0.0]])
model = EqxModel(A, B, C, D)

# Test simulation
inputs = jnp.array([[0.0], [1.0], [1.0], [1.0]])
x0 = jnp.zeros(2)
xk, outputs = simulate(model, inputs, x0)
assert jnp.allclose(xk, jnp.array([2.7, 3.0]))
assert jnp.allclose(outputs, jnp.array([[0.0], [0.0], [0.0], [0.3]]))

# This raises TypeError
xk, outputs = jax.lax.scan(model, x0, inputs)
```
What is `unhashable type: 'ArrayImpl'` referring to? Is it the arrays `A`, `B`, `C`, and `D`? In this model, these matrices are parameters and therefore should be static for the duration of the simulation.
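One way to probe whether the arrays are the culprit is to hash a model-like object directly. The sketch below is a minimal illustration under an assumption: that `eqx.Module` generates a field-based `__hash__` the same way a frozen standard-library dataclass does. `FrozenModel` is a hypothetical stand-in, not part of Equinox.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for eqx.Module (assumption: Equinox modules hash
# like this). With eq=True (the default) and frozen=True, dataclasses
# generates a __hash__ that hashes the tuple of field values.
@dataclasses.dataclass(frozen=True)
class FrozenModel:
    A: jax.Array

    def __call__(self, xk, uk):
        return self.A @ xk, xk


model = FrozenModel(A=jnp.eye(2))

try:
    hash(model)
    message = "model is hashable"
except TypeError as err:
    # JAX arrays are unhashable, so the generated __hash__ fails on field A.
    message = str(err)

# A lambda is hashed by identity, so wrapping the model hashes fine.
wrapper = lambda xk, uk: model(xk, uk)
wrapper_hash = hash(wrapper)

print(message)
```

If the assumption holds, `hash(model)` on the `EqxModel` instance above should fail with the same kind of message, which points at the array fields rather than at `__call__`.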
I just found this issue thread that might be related:
Owen Lockwood (lockwo) provided an explanation and an answer in this issue thread, which I reiterate below:
> I believe your issue is happening because jax tries to hash the function you are scanning over, but it can't hash the arrays that are in the module. There are probably a number of things that you could do to solve this, the simplest being to just curry the model, e.g.
>
> ```python
> xk, outputs = jax.lax.scan(lambda carry, y: model(carry, y), x0, inputs)
> ```
>
> works fine
Or, rewritten in terms of the variable names I am using:

```python
xk, outputs = jax.lax.scan(lambda xk, uk: model(xk, uk), x0, inputs)
```