bbartling

Reputation: 3494

Convert PyTorch model to ONNX

How do I convert a PyTorch model to ONNX? I am trying this approach on Python 3.7:

import torch

model = torch.load("./yolov7x.pt")

#torch.onnx.export(model, "yolo_v7x.onnx")

Even with the last line commented out, just loading the model errors out:

Traceback (most recent call last):
  File "C:\Users\convert_onx.py", line 5, in <module>
    model = torch.load("./yolov7x.pt")
  File "C:\Users\Python37\lib\site-packages\torch\serialization.py", line 594, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\Python37\lib\site-packages\torch\serialization.py", line 853, in _load
    result = unpickler.load()
ModuleNotFoundError: No module named 'models'

This is the git repo I am working with for the YOLOv7x model: https://github.com/WongKinYiu/yolov7

The ultimate use case is to run this model with Intel's OpenVINO toolkit, which requires PyTorch models to be converted to ONNX first.
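Once the ONNX file exists, the plan is roughly the following (an untested sketch assuming the 2022-era openvino.runtime Python API; the model path and input shape are placeholders):

import numpy as np
from openvino.runtime import Core

# Load the exported ONNX model directly with the OpenVINO runtime
core = Core()
model = core.read_model("yolov7x.onnx")
compiled = core.compile_model(model, "CPU")

# Dummy input matching the shape used during ONNX export
dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)
result = compiled([dummy])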

Upvotes: 1

Views: 14632

Answers (2)

u1234x1234

Reputation: 2510

When you load a pickled model, the source tree must match the one that was used when the model was saved. So

ModuleNotFoundError: No module named 'models'

means this directory needs to be on your Python path: https://github.com/WongKinYiu/yolov7/tree/main/models

To export to ONNX:

  1. Clone the repo https://github.com/WongKinYiu/yolov7:

     git clone https://github.com/WongKinYiu/yolov7

  2. Set the correct path to it:

     import sys
     sys.path.insert(0, './yolov7')

     or set the PYTHONPATH environment variable instead.

  3. You may also need a specific torch version; I've checked and torch==1.8.0 seems to work fine.

Example:

import torch
import sys

# Make the repo's 'models' package importable so unpickling can find it
sys.path.insert(0, './yolov7')

device = torch.device('cpu')
# The .pt checkpoint is a dict; the nn.Module is stored under the 'model' key
model = torch.load('yolov7x.pt', map_location=device)['model'].float()
# Export by tracing the model with a dummy input of shape (batch, channels, height, width)
torch.onnx.export(model, torch.zeros((1, 3, 640, 640)), 'yolov7.onnx', opset_version=12)

After that the model was exported to ONNX (visualized with Netron to confirm).

Usually it is better to save the weights as a state_dict and keep the source code that can reconstruct the torch.nn.Module; then you can safely use:

model.load_state_dict(torch.load('weights.pt'))
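For reference, a minimal sketch of that pattern (MyModel and weights.pt are placeholders, not part of the YOLOv7 repo):

import torch
import torch.nn as nn

# Hypothetical model definition kept in source control
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

# Saving: store only the weights, not the pickled module tree
model = MyModel()
torch.save(model.state_dict(), 'weights.pt')

# Loading: reconstruct the module from source, then load the weights
model = MyModel()
model.load_state_dict(torch.load('weights.pt'))
model.eval()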

Upvotes: 2

bbartling

Reputation: 3494

In the YOLOv7 repo: https://github.com/WongKinYiu/yolov7

Just use the Google Colab notebooks provided in the repo to convert the PyTorch weights to other model formats...easy as pie!
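For a non-notebook route, the repo also ships an export.py script; a rough invocation looks like the line below (flag names may differ between repo revisions, so treat this as an assumption and check the script's arguments):

python export.py --weights yolov7x.pt --grid --simplify --img-size 640 640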

Upvotes: 0
