Federico Taschin

Reputation: 2195

JAX: does jitting functions separately change performance?

I am learning to use JAX and I have some doubts about the use of jit and vmap that I couldn't resolve by reading the docs.

  1. Does it make a difference to jit several functions separately and then jit the function that uses them? For example, if I have the functions foo() and bar() and a function

    @jax.jit 
    def fooBar(x):
        return foo(x) + bar(x)
    

    Is there any difference if foo() and bar() are already jitted?

  2. Should I jit a function after I vmap it? In the example above, should I use jax.jit(jax.vmap(fooBar)) or just jax.vmap(fooBar)?

Upvotes: 2

Views: 533

Answers (1)

jakevdp

Reputation: 86320

When it comes to execution performance, there is no difference between jitting functions separately and jitting once at the outer function: the outer jit traces through the inner jitted calls, so XLA still compiles the whole computation as a single program. Functionally there is one subtle difference: jit-compiling the inner function wraps its contents in an xla_call primitive (newer JAX versions use a differently named call primitive), but this makes little to no difference to the final compilation and execution.
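A minimal sketch of the comparison, assuming simple stand-ins for foo() and bar() (the bodies below are hypothetical, chosen only for illustration). Both variants return the same values; printing the jaxprs lets you see that the only structural difference is the nested call primitives around the inner functions.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the question's foo() and bar().
def foo(x):
    return jnp.sin(x) * 2.0

def bar(x):
    return jnp.cos(x) ** 2

# Variant A: only the outer function is jitted.
@jax.jit
def foo_bar_outer_only(x):
    return foo(x) + bar(x)

# Variant B: the inner functions are jitted as well.
foo_jit = jax.jit(foo)
bar_jit = jax.jit(bar)

@jax.jit
def foo_bar_nested(x):
    return foo_jit(x) + bar_jit(x)

x = jnp.linspace(0.0, 1.0, 5)

# Both variants compute the same values.
print(jnp.allclose(foo_bar_outer_only(x), foo_bar_nested(x)))

# In variant B the inner calls appear as nested call primitives in the jaxpr;
# the outer jit still compiles everything as one XLA program.
print(jax.make_jaxpr(foo_bar_outer_only)(x))
print(jax.make_jaxpr(foo_bar_nested)(x))
```

The exact primitive name printed in the jaxpr depends on your JAX version.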

When using vmap, on the other hand, there is no implicit compilation: vmap(f) is executed in eager mode (op by op), while jit(vmap(f)) is just-in-time compiled and will generally execute faster.
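A rough way to check this yourself is sketched below. The function foo_bar is a hypothetical stand-in, deliberately left un-jitted so that vmap alone has nothing to compile; the timing helper is also illustrative, and the numbers you get will depend on your hardware and JAX version.

```python
import time

import jax
import jax.numpy as jnp

def foo_bar(x):
    # Hypothetical, un-jitted stand-in for the question's fooBar.
    return jnp.sin(x) * 2.0 + jnp.cos(x) ** 2

batched_eager = jax.vmap(foo_bar)              # dispatched op by op, no whole-function compilation
batched_compiled = jax.jit(jax.vmap(foo_bar))  # traced once and compiled by XLA as one program

xs = jnp.ones((1000, 128))

def time_call(f, x, repeats=20):
    # Warm-up call; for the jitted version this includes compilation.
    f(x).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # block_until_ready() waits for JAX's asynchronous dispatch to finish.
        f(x).block_until_ready()
    return (time.perf_counter() - start) / repeats

print("vmap only:   ", time_call(batched_eager, xs))
print("jit(vmap()): ", time_call(batched_compiled, xs))
```

Excluding the warm-up call from the measurement keeps one-time compilation cost separate from steady-state execution time.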

Upvotes: 2
