If I call the dummy function defined below, an error is raised because jnp.iscomplex(x) returns a tracer object. But x is fixed, so I'd expect jnp.iscomplex(x) to return False.
```python
import jax
import jax.numpy as jnp

x = jnp.array(3)

@jax.jit
def dummy():
    if jnp.iscomplex(x):
        print("Is complex!")

dummy()  # raises an error: the condition is a tracer, not a Python bool
```
Is it possible to avoid jitting jnp.iscomplex?
No, you cannot normally¹ cause part of a JIT-compiled function to be executed outside the JIT context. But you may be able to do what you have in mind by checking static attributes, such as the dtype, which are known at trace time:
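```python
@jax.jit
def dummy():
    # x.dtype is static (known at trace time), so this condition is a plain Python bool
    if jnp.issubdtype(x.dtype, jnp.complexfloating):
        print("Is complex!")

dummy()  # runs without error; x is an integer array, so nothing is printed
```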
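This has slightly different semantics from jnp.iscomplex, which returns True or False elementwise depending on whether the imaginary part is nonzero: an array with a complex dtype but a zero imaginary part passes the dtype check, yet jnp.iscomplex returns False for it. For some background on why it is not possible to use Python control flow conditional on array values in JIT, see How To Think In JAX.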
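To make the difference concrete, here is a small sketch; `dtype_is_complex` and `f` are hypothetical helper names used only for illustration. It compares the value-based and dtype-based checks on a complex array with zero imaginary part, and shows that the dtype check also works on a traced function argument, because a tracer's dtype is concrete even though its values are not.

```python
import jax
import jax.numpy as jnp

def dtype_is_complex(a):
    # Hypothetical helper: depends only on the dtype, which is static under jit.
    return jnp.issubdtype(a.dtype, jnp.complexfloating)

z = jnp.array(3 + 0j)  # complex dtype, but the imaginary part is zero

print(bool(jnp.iscomplex(z)))  # value-based check: False
print(dtype_is_complex(z))     # dtype-based check: True

@jax.jit
def f(a):
    # `a` is a tracer here, but a.dtype is known, so a Python `if` on it is allowed.
    if dtype_is_complex(a):
        return jnp.real(a)
    return a

print(f(jnp.array(2 + 0j)))  # takes the complex-dtype branch
print(f(jnp.array(3)))       # takes the other branch
```

Note that each distinct input dtype triggers a separate trace of `f`, which is what lets the branch be resolved in Python.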
Alternatively, if you'd like to compute static functions of static values, you can use numpy values and functions rather than their jax.numpy counterparts; NumPy operations run eagerly during tracing and return concrete results:
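```python
import jax
import numpy as np

x = np.array(3)

@jax.jit
def dummy():
    # np.iscomplex is evaluated by NumPy during tracing and returns a concrete bool
    if np.iscomplex(x):
        print("Is complex!")

dummy()  # runs without error
```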
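As a further sketch of the NumPy approach, the hypothetical factory `make_shifter` below closes over a NumPy array and picks a branch with `np.iscomplex` inside a jitted function. This assumes the captured value is a NumPy array (not a jax.numpy array), so the check is resolved at trace time and only the chosen branch is compiled.

```python
import jax
import jax.numpy as jnp
import numpy as np

def make_shifter(x):
    # Hypothetical example: x is a NumPy array captured by closure.
    @jax.jit
    def shift(y):
        # np.iscomplex(x) runs in NumPy while JAX traces `shift`,
        # so Python's `if` receives a concrete NumPy bool.
        if np.iscomplex(x):
            return y + 1
        return y - 1
    return shift

print(make_shifter(np.array(3))(10.0))       # real x: subtracts 1
print(make_shifter(np.array(3 + 1j))(10.0))  # complex x: adds 1
```

Because the branch is fixed at trace time, changing `x` afterwards will not affect an already-compiled `shift`; that is the trade-off of treating `x` as static.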
¹ It is technically possible to do what you want using an external callback, but because a callback leaves the compiled computation to run Python code on the host, there are performance implications that make it the wrong solution in most cases.
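For completeness, a minimal sketch of the callback route, assuming `jax.debug.callback` as the external-callback API; `report_if_complex` and the `results` list are hypothetical names for illustration. The check runs on the host every time the compiled function executes, which is where the performance cost comes from.

```python
import jax
import jax.numpy as jnp
import numpy as np

results = []  # hypothetical log of what the host-side check saw

def report_if_complex(value):
    # Runs on the host with a concrete array, so Python control flow on its value is fine.
    is_complex = bool(np.iscomplex(np.asarray(value)))
    results.append(is_complex)
    if is_complex:
        print("Is complex!")

@jax.jit
def dummy(x):
    # The callback is staged into the compiled function and invoked at run time.
    jax.debug.callback(report_if_complex, x)

dummy(jnp.array(3))
dummy(jnp.array(3 + 1j))
jax.effects_barrier()  # wait for any pending callbacks to finish
```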