amitjans

Avoid automatic jitting of jax.numpy functions when using jax.jit

If I call the dummy function defined below, an error is raised, because inside jax.jit, jnp.iscomplex(x) returns a tracer object rather than a concrete boolean.

But x is fixed, so I'd expect jnp.iscomplex(x) to simply evaluate to False.

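import jax
import jax.numpy as jnp

x = jnp.array(3)

@jax.jit
def dummy():
  # Inside jit, jnp.iscomplex(x) is traced, so `if` receives a tracer
  if jnp.iscomplex(x):
    print("Is complex!")

dummy()  # raises an error: the tracer cannot be converted to a Python bool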

Is it possible to avoid jitting jnp.iscomplex?

Answers (1)

jakevdp

No, you cannot normally¹ cause part of a JIT-compiled function to be executed outside the JIT context. But you may be able to do what you have in mind by accessing static attributes, namely the dtype:

@jax.jit
def dummy():
  if jnp.issubdtype(x.dtype, jnp.complexfloating):
    print("Is complex!")

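This has slightly different semantics from jnp.iscomplex, which returns True or False elementwise depending on whether the imaginary part is nonzero: an array with a complex dtype but a zero imaginary part passes the dtype check, yet jnp.iscomplex returns False for it. For some background on why Python control flow conditional on array values is not possible under JIT, see How To Think In JAX.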
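A small sketch of the distinction, using hypothetical helper names abs_squared and has_nonzero_imag (they are not part of JAX): branching on the static dtype is fine inside jax.jit, while the value-based jnp.iscomplex stays a traced array and has to be returned or consumed by a traceable op rather than a Python `if`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def abs_squared(x):
    # x.dtype is a static attribute: it is known while tracing,
    # so an ordinary Python `if` on it is allowed inside jit.
    if jnp.issubdtype(x.dtype, jnp.complexfloating):
        return (x * jnp.conj(x)).real
    return x * x


@jax.jit
def has_nonzero_imag(x):
    # jnp.iscomplex depends on the *values* of x, so under jit its result is a
    # tracer: return it (or feed it to jnp.where / jax.lax.cond) instead of `if`.
    return jnp.iscomplex(x)


z = jnp.array(3 + 0j)  # complex dtype, zero imaginary part
print(jnp.issubdtype(z.dtype, jnp.complexfloating))  # True
print(bool(has_nonzero_imag(z)))                      # False
print(abs_squared(jnp.array(3 + 4j)))                 # 25.0
```

Note that abs_squared is retraced once per input dtype, which is what lets the dtype branch be resolved statically.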

Alternatively, if you'd like to compute static functions of static values, you can use NumPy arrays and functions rather than their jax.numpy counterparts; NumPy operations run as ordinary Python at trace time and return concrete values:

import numpy as np

x = np.array(3)

@jax.jit
def dummy():
  if np.iscomplex(x):
    print("Is complex!")
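A slightly fuller sketch of the NumPy approach, with a hypothetical function shift that takes a traced argument y: the NumPy check picks which computation gets compiled, and the global x is folded in as a constant.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = np.array(3)  # a NumPy array: JAX never turns it into a tracer


@jax.jit
def shift(y):
    # np.iscomplex runs as ordinary Python while `shift` is being traced and
    # returns a concrete NumPy bool, so this `if` is resolved at trace time.
    if np.iscomplex(x):
        return y * 2
    # x is embedded in the compiled computation as a constant.
    return y + x


print(shift(jnp.array(1.0)))  # 4.0
```

Because x is read at trace time, reassigning the global x later does not by itself cause shift to be retraced, so the compiled function can keep using the old value.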

¹ It is technically possible to do what you want using an external callback such as jax.pure_callback, but each call requires a round trip from the compiled computation back to the Python host, which makes it the wrong solution in most cases.
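A minimal sketch of the callback route, assuming a scalar x and the jax.pure_callback API available in recent JAX versions; host_iscomplex and host_branch are hypothetical names. The callback's result still comes back as a traced value inside jit, so the branch uses jax.lax.cond rather than a Python `if`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_iscomplex(x):
    # Runs on the host as plain NumPy, outside the compiled computation.
    return np.asarray(np.iscomplex(np.asarray(x)))


@jax.jit
def host_branch(x):
    flag = jax.pure_callback(
        host_iscomplex,
        jax.ShapeDtypeStruct(x.shape, jnp.bool_),  # declared result shape/dtype
        x,
    )
    # `flag` is still a traced value inside jit, so branch with jax.lax.cond
    # rather than a Python `if` (lax.cond needs a scalar predicate).
    return jax.lax.cond(flag, lambda v: v * 2, lambda v: v + 1, x)


print(host_branch(jnp.array(3.0)))  # 4.0
print(host_branch(jnp.array(3 + 4j)))
```

Every call of host_branch pays for the host round trip, on top of both lax.cond branches having to be traceable.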