Reputation: 309
I have the following doubt about JAX. I'll use an example from the official optax docs to illustrate it:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
In this example, the function step uses the variable optimizer even though it is not passed as a function argument (since the function is being jitted and optax.GradientTransformation is not a supported type). However, the same function uses other variables that are passed as arguments (i.e., params, opt_state, batch, labels). I understand that JAX functions need to be pure in order to be jitted, but what about input (read-only) variables? Is there any difference between accessing a variable by passing it through the function arguments and accessing it directly because it is in the step function's enclosing scope? What if this variable is not constant but is modified between separate step calls? Are such variables treated like static arguments when accessed directly? Or are they simply jitted away, so that modifications to them will not be picked up?
To be more specific, let's look at the following example:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    updates *= extra_learning_rate  # not really valid code, but you get the idea
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    extra_learning_rate = 0.01  # does this affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels)

  return params
vs
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels, extra_lr):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    updates *= extra_lr  # not really valid code, but you get the idea
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
    extra_learning_rate = 0.01  # does this now affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)

  return params
From my limited experiments, the two versions behave differently: in the global-variable case, the second step call doesn't use the new learning rate, and no 're-jitting' happens either. However, I'd like to know whether there are standard practices/rules I need to be aware of. I'm writing a library where performance is fundamental, and I don't want to miss out on JIT optimizations because I'm doing things wrong.
Upvotes: 5
Views: 1788
Reputation: 86320
During JIT tracing, JAX treats global values as implicit arguments to the function being traced. You can see this reflected in the jaxpr representing the function.
Here are two simple functions that return equivalent results, one taking both inputs as explicit arguments and one taking an input implicitly via closure:
import jax
import jax.numpy as jnp

def f_explicit(a, b):
  return a + b

def f_implicit(b):
  return a_global + b

a_global = jnp.arange(5.0)
b = jnp.ones(5)

print(jax.make_jaxpr(f_explicit)(a_global, b))
# { lambda ; a:f32[5] b:f32[5]. let c:f32[5] = add a b in (c,) }

print(jax.make_jaxpr(f_implicit)(b))
# { lambda a:f32[5]; b:f32[5]. let c:f32[5] = add a b in (c,) }
Notice the only difference in the two jaxprs is that in f_implicit, the a variable comes before the semicolon: this is the way that jaxpr representations indicate the argument is passed via closure rather than via an explicit argument. But the computation generated by these two functions will be identical.
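For instance, continuing from the snippet above (a quick sanity check, reusing the definitions already in scope), jitting both versions gives the same result for these inputs:
out_explicit = jax.jit(f_explicit)(a_global, b)
out_implicit = jax.jit(f_implicit)(b)
print(jnp.allclose(out_explicit, out_implicit))  # True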
That said, one difference to be aware of is that when a value passed by closure is a hashable constant, it will be treated as static within the traced function (similar to when explicit arguments are marked static via static_argnums or static_argnames in jax.jit):
a_global = 1.0
print(jax.make_jaxpr(f_implicit)(b))
# { lambda ; a:f32[5]. let b:f32[5] = add 1.0 a in (b,) }
Notice that in the jaxpr representation the constant value is inserted directly as an argument to the add operation. The explicit way to get the same result for a JIT-compiled function would look something like this:
from functools import partial

@partial(jax.jit, static_argnames=['a'])
def f_explicit(a, b):
  return a + b
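To tie this back to the two fit variants in the question: a hashable Python value captured by closure is baked in when the function is first traced, so rebinding the global afterwards has no effect on later calls (and triggers no re-tracing), while a value passed as an ordinary traced argument is picked up fresh on every call without recompilation. A minimal sketch of that difference, using made-up names (scale, f_closure, f_arg) and reusing the imports above:
scale = 2.0

@jax.jit
def f_closure(x):
  # `scale` is a hashable Python float captured by closure, so it is
  # baked into the compiled program when the function is first traced.
  return scale * x

x = jnp.arange(3.0)
print(f_closure(x))   # [0. 2. 4.]

scale = 10.0          # rebinding the global does not trigger a re-trace
print(f_closure(x))   # still [0. 2. 4.] -- the cached trace keeps the old value

@jax.jit
def f_arg(x, scale):
  # here `scale` is an ordinary traced argument, so new values are picked
  # up on every call without recompilation
  return scale * x

print(f_arg(x, 2.0))   # [0. 2. 4.]
print(f_arg(x, 10.0))  # [0. 10. 20.]
If you instead marked such an argument static via static_argnames, each new value would trigger a fresh trace and compile, which is generally worthwhile only for values that change the structure of the computation rather than just its inputs.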
Upvotes: 7