Reputation: 6349
There is a library that converts JAX functions to TensorFlow functions. Is there a similar library that converts TensorFlow functions to JAX functions?
Upvotes: 2
Views: 1128
Reputation: 151
Is https://github.com/google-deepmind/tf2jax what you were looking for? Note that it only works with TensorFlow 2.
Upvotes: 1
Reputation: 11
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md.
"jax2tf.call_tf: for using TensorFlow functions in a JAX context, e.g., to call a TensorFlow library or a SavedModel inside a JAX function."
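That is what you need: it lets you call a TensorFlow function from inside JAX code. For example:

```python
import jax
import tensorflow as tf
from jax.experimental import jax2tf

# Compute cos with TF and sin with JAX
def cos_tf_sin_jax(x):
    return jax.numpy.sin(jax2tf.call_tf(tf.math.cos)(x))
```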
Upvotes: 1
Reputation: 740
To my knowledge, there is no library similar to the one you mentioned for converting TensorFlow functions to JAX functions. I'm sorry.
Upvotes: 0
Reputation: 86328
No, there is no library supported by the JAX team that converts TensorFlow into JAX in the way that jax.experimental.jax2tf converts JAX code to TensorFlow, and I have not seen any such library developed by others.
Upvotes: 3