Tom McLean

Reputation: 6349

Is there a module to convert a TensorFlow NN to JAX?

There is a library to convert JAX functions to TensorFlow functions. Is there a similar library to convert TensorFlow functions to JAX functions?

Upvotes: 2

Views: 1128

Answers (4)

Junmin Hao

Reputation: 151

Is https://github.com/google-deepmind/tf2jax what you were looking for? It only works for TF v2 though.

Upvotes: 1

John Zhang

Reputation: 11

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md.

"jax2tf.call_tf: for using TensorFlow functions in a JAX context, e.g., to call a TensorFlow library or a SavedModel inside a JAX function."

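That is what you need: it lets you call a TF function from within JAX. Note that it wraps the TF function rather than translating it into JAX code.

For example:

import jax
import tensorflow as tf
from jax.experimental import jax2tf

# Compute cos with TF (tf.math.cos stands in for cos_tf) and sin with JAX
def cos_tf_sin_jax(x):
  return jax.numpy.sin(jax2tf.call_tf(tf.math.cos)(x))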

Upvotes: 1

Erling Olsen

Reputation: 740

To my knowledge, there is no library similar to the one you mentioned that converts TensorFlow functions to JAX functions. Sorry.

Upvotes: 0

jakevdp

Reputation: 86328

No, there is no library supported by the JAX team to convert TensorFlow code into JAX in a manner similar to how jax.experimental.jax2tf converts JAX code to TensorFlow, and I have not seen any such library developed by others.

Upvotes: 3
