iNVA

Reputation: 66

How do I convert a saved TensorFlow model to PyTorch when it can't be converted through ONNX?

Following the instructions in this Colab notebook, I downloaded a pre-trained TF model. Now I want to convert it to PyTorch. I tried several conversion options, including ONNX, but none of them worked. When I contacted the authors of the repo, they told me they trained the model in JAX and converted it to TensorFlow for the release.
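If the ONNX route stays blocked, one manual alternative is to rebuild the network in PyTorch and copy the weights across as NumPy arrays. The sketch below only covers the layout bookkeeping, assuming the PVN network is made of standard convolution and dense layers (not confirmed here) and that its kernels follow the usual TF/Flax conventions: conv kernels as [height, width, in, out] and dense kernels as [in, out]. PyTorch's Conv2d stores its weight as [out, in, height, width] and Linear as [out, in]. All function names are hypothetical.

```python
import numpy as np


def conv_kernel_to_torch(kernel_hwio: np.ndarray) -> np.ndarray:
    """TF/Flax conv kernel [H, W, in, out] -> PyTorch Conv2d weight [out, in, H, W]."""
    return np.ascontiguousarray(kernel_hwio.transpose(3, 2, 0, 1))


def dense_kernel_to_torch(kernel_io: np.ndarray) -> np.ndarray:
    """TF/Flax dense kernel [in, out] -> PyTorch Linear weight [out, in]."""
    return np.ascontiguousarray(kernel_io.T)


def flattened_dense_kernel_to_torch(kernel_io: np.ndarray, h: int, w: int, c: int) -> np.ndarray:
    """Dense kernel that reads a flattened NHWC feature map -> PyTorch Linear weight
    that reads the same feature map flattened in NCHW order."""
    out_features = kernel_io.shape[1]
    # Rows are ordered (h, w, c) in TF/Flax; flattening an NCHW tensor yields (c, h, w).
    reordered = kernel_io.reshape(h, w, c, out_features).transpose(2, 0, 1, 3)
    return np.ascontiguousarray(reordered.reshape(h * w * c, out_features).T)


# Self-check on random data: both flatten orders give the same dense-layer output.
rng = np.random.default_rng(0)
h, w, c, out = 3, 4, 5, 6
x_hwc = rng.standard_normal((h, w, c))
kernel = rng.standard_normal((h * w * c, out))
y_tf = x_hwc.reshape(-1) @ kernel
y_torch = flattened_dense_kernel_to_torch(kernel, h, w, c) @ x_hwc.transpose(2, 0, 1).reshape(-1)
print(np.allclose(y_tf, y_torch))  # True
```

The flatten-aware version is only needed for the first dense layer after a flatten; every other layer just needs the plain transposes. Separately, TF's 'SAME' padding with a stride above 1 can pad unevenly, so a PyTorch port may need an explicit torch.nn.functional.pad before such a convolution.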

Here is how I tried the ONNX route:

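import tensorflow as tf
import tf2onnx

# concrete_func was taken from the model loaded with tf.saved_model.load (definition not shown)
onnx_model, _ = tf2onnx.convert.from_function(
  function=concrete_func,
  input_signature=[tf.TensorSpec(shape=[None, 84, 84, 4], dtype=tf.float32)],
  opset=13
)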

This fails with the error '_WrapperFunction' object has no attribute 'get_concrete_function'.

The loaded model also doesn't seem to have a variables attribute, so I can't read the weights out directly either, and I am unsure how to proceed. Inference does work after loading the model, though, so the weights must be stored somewhere.
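Inference working while the loaded object exposes no variables is consistent with the weights being stored as constants inside the graph rather than as checkpointed variables; a JAX-to-TF export can end up that way if the parameters were closed over during conversion. That is an assumption, not something the authors confirmed. A cheap first check is what the export directory holds on disk: a standard SavedModel keeps its graph in saved_model.pb and any variables in variables/variables.index plus variables/variables.data-* shards. The helper name below is hypothetical.

```python
from pathlib import Path


def describe_saved_model(export_dir) -> dict:
    """Summarize what a SavedModel export directory contains (hypothetical helper)."""
    root = Path(export_dir)
    var_dir = root / "variables"
    var_files = sorted(p.name for p in var_dir.glob("variables.*")) if var_dir.is_dir() else []
    has_data_shards = any(name.startswith("variables.data") for name in var_files)
    return {
        "graph_file": (root / "saved_model.pb").is_file(),
        "variable_files": var_files,
        # No data shards: the weights must live elsewhere, most likely as graph constants.
        "weights_probably_in_graph": not has_data_shards,
    }
```

In the notebook this would be called as describe_saved_model(gamedir). If it reports no data shards, the weights would have to be recovered from the graph's constant nodes rather than from a checkpoint.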

Inference code:

import numpy as np
import tensorflow as tf

# @title Reload saved model
# gamedir: pathlib.Path to the downloaded model directory (defined earlier in the notebook)
with tf.device('/device:GPU:0'):
  pvn = tf.saved_model.load(gamedir.as_posix())

# @title Perform forward pass to get PVN features
# Atari observation: [batch_size, width, height, frame_stack]
obs = np.zeros((1, 84, 84, 4), dtype=np.float32)

with tf.device('/device:GPU:0'):
  features = pvn(obs).numpy()
  print(f"Shape: {features.shape!r}")
  print(features)
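When checking a PyTorch port against the SavedModel, the observation has to be fed in each framework's layout: the TF model takes [batch_size, 84, 84, frame_stack] (channels last), while PyTorch's Conv2d expects channels first. A minimal sketch, assuming a hypothetical torch_model that takes channels-first float32 input and returns a flat [batch, dim] feature vector; the helper names are hypothetical too. A random observation is a better test input than np.zeros, because an all-zero input never touches the first layer's kernel and would hide a layout mistake there.

```python
import numpy as np


def nhwc_to_nchw(obs: np.ndarray) -> np.ndarray:
    """[batch, 84, 84, frame_stack] (TF/JAX layout) -> [batch, frame_stack, 84, 84] (PyTorch)."""
    return np.ascontiguousarray(obs.transpose(0, 3, 1, 2))


def features_match(tf_features, torch_features, atol: float = 1e-5) -> bool:
    """True if both feature arrays have the same shape and agree within atol."""
    a = np.asarray(tf_features)
    b = np.asarray(torch_features)
    return a.shape == b.shape and bool(np.allclose(a, b, atol=atol))


# Random observation instead of zeros, so the first layer's kernel actually contributes.
obs = np.random.default_rng(0).random((1, 84, 84, 4), dtype=np.float32)
obs_torch_layout = nhwc_to_nchw(obs)
print(obs_torch_layout.shape)  # (1, 4, 84, 84)
```

With a port in hand, the comparison would look roughly like features_match(pvn(obs).numpy(), torch_model(torch.from_numpy(obs_torch_layout)).detach().numpy()).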

