Reputation: 2195
I am learning to use JAX and I have some questions about the use of `jit` and `vmap` that I couldn't resolve by reading the docs.
Does it make a difference whether I `jit` several functions separately and then `jit` the function that uses them? For example, if I have the functions `foo()` and `bar()` and a function
```python
import jax

@jax.jit
def fooBar(x):
    return foo(x) + bar(x)
```
Is there any difference if `foo()` and `bar()` are already jitted?
Should I `jit` a function after I `vmap` it? In the example above, should I do `jax.jit(jax.vmap(fooBar))` or just `jax.vmap(fooBar)`?
Upvotes: 2
Views: 533
Reputation: 86320
When it comes to execution performance, there is no difference between jitting functions separately and jitting only the outer function. (Functionally, there is one subtle difference: jit-compiling the inner functions wraps their contents in an `xla_call` primitive, but this makes little to no difference to the final compilation and execution.)
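A minimal sketch of that comparison follows. The question doesn't define `foo()` and `bar()`, so the bodies below are hypothetical stand-ins. Both variants should produce the same values; printing the jaxprs shows the structural difference, namely the extra nested calls around `foo` and `bar` in the second variant (the exact primitive name printed depends on the JAX version).

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the question's foo() and bar().
def foo(x):
    return jnp.sin(x) * 2.0

def bar(x):
    return jnp.cos(x) ** 2

# Variant A: jit only the outer function; foo and bar are traced into it.
@jax.jit
def foo_bar_outer_only(x):
    return foo(x) + bar(x)

# Variant B: jit foo and bar individually, then jit the outer function too.
foo_jit = jax.jit(foo)
bar_jit = jax.jit(bar)

@jax.jit
def foo_bar_nested(x):
    return foo_jit(x) + bar_jit(x)

x = jnp.linspace(0.0, 1.0, 5)

# Same numerical result from both variants.
print(bool(jnp.allclose(foo_bar_outer_only(x), foo_bar_nested(x))))  # True

# Variant B's jaxpr contains nested call primitives for foo and bar;
# XLA still compiles the whole outer function as one computation.
print(jax.make_jaxpr(foo_bar_outer_only)(x))
print(jax.make_jaxpr(foo_bar_nested)(x))
```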
When using `vmap`, on the other hand, there is no implicit compilation: `vmap(f)` will be executed in eager mode, while `jit(vmap(f))` will be just-in-time compiled and will generally execute faster.
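A rough sketch of the two options, assuming `f` is a plain Python function with no `@jax.jit` of its own (the function body and input shape are hypothetical). Both calls return the same values; the timing lines only illustrate how one might measure the difference, and actual numbers depend on hardware and problem size.

```python
import timeit

import jax
import jax.numpy as jnp

# Hypothetical per-example function, deliberately NOT decorated with @jax.jit.
def f(x):
    return jnp.sum(jnp.tanh(x) * jnp.exp(-x ** 2))

batched_eager = jax.vmap(f)         # vectorized, but dispatched op by op
batched_jit = jax.jit(jax.vmap(f))  # vectorized and compiled as one XLA computation

xs = jax.random.normal(jax.random.PRNGKey(0), (1000, 128))

# Same result either way; only how the work is executed differs.
print(bool(jnp.allclose(batched_eager(xs), batched_jit(xs))))  # True

# Warm-up call so compilation time is excluded from the measurement.
batched_jit(xs).block_until_ready()

# block_until_ready() waits for JAX's asynchronous dispatch to finish.
t_eager = timeit.timeit(lambda: batched_eager(xs).block_until_ready(), number=100)
t_jit = timeit.timeit(lambda: batched_jit(xs).block_until_ready(), number=100)
print(f"vmap only: {t_eager:.4f}s   jit(vmap): {t_jit:.4f}s")
```

The gap tends to be most visible for functions built from many small operations, since eager mode pays dispatch overhead per operation while the jitted version can fuse them.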
Upvotes: 2