Reputation: 11
Having an issue converting a PyTorch script to ONNX. I'm getting errors saying:
```
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Unexpected input data type. Actual: (tensor(uint8)) , expected: (tensor(float))
```
When I look at the model it shows:
```
format     ONNX v8
producer   pytorch 2.1.1
imports    ai.onnx v17
stride     32
names      {0: 'player'}
graph      main_graph

images     name: images
           tensor: float32[1,3,320,320]
output0    name: output0
           tensor: float32[1,6300,6]
```
Which, to my understanding, is float32. So why is it saying the data type is uint8?
Model export command:

```
python .\export.py --weights ./best.pt --include onnx --imgsz 320 320 --device 0
```

The session is loaded with:
```python
print("[INFO] Loading ONNX model")
self.onnx_session = onnxruntime.InferenceSession(onnx_model_path)
```
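For reference, the session's expected inputs can be listed with onnxruntime's standard metadata API (a minimal sketch reusing the names from the snippet above):

```python
# Print what the loaded session actually expects for each input.
for inp in self.onnx_session.get_inputs():
    print(inp.name, inp.shape, inp.type)
# For this model: images [1, 3, 320, 320] tensor(float)
```

This makes it easy to compare the expected dtype and shape against the array actually being fed in.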
I tried:
```python
frame = frame.transpose((2, 0, 1))        # Change from HWC to CHW format
frame = frame[np.newaxis, :]              # Add batch size dimension
frame = frame.astype(np.float32) / 255.0  # Normalize pixel values to [0, 1]
results = self.onnx_session.run(['output0'], {'images': frame})
```
The code prepares the image frame for ONNX model inference by changing its layout to CHW (channels, height, width), adding a batch dimension, and normalizing the pixel values. The processed image is then fed to the ONNX model for inference, and the results are stored in the results variable.
Which leads to this:
```
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: images for the following indices
 index: 1 Got: 4 Expected: 3
 Please fix either the inputs or the model.
```
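Printing the tensor's shape and dtype right before the run call confirms the mismatch (a minimal check; the example output is inferred from the error above):

```python
# Sanity-check the array being fed to the session.
print(frame.shape, frame.dtype)  # (1, 4, 320, 320) float32 -> 4 channels instead of 3
```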
What am I missing?
Upvotes: 1
Views: 672
Reputation: 901
From the error message you are getting,

```
Got invalid dimensions for input: images for the following indices
 index: 1 Got: 4 Expected: 3
```
we see that inference fails because the ONNX model expects the input array to have three channels in the second dimension, whereas the actual input has four. Considering your input is an image,

```
images  tensor: float32[1,3,320,320]
```

the second dimension is most likely for the red, green and blue colour channels. Since your input has one channel too many, my bet is that your image is in RGBA format, hence the one extra channel (alpha) and the error message.
Try converting your image to RGB before feeding it into the graph:

```python
import numpy as np
from PIL import Image

image = Image.open("/path/to/image")
image = image.convert("RGB")  # drops the alpha channel
frame = np.asarray(image)
...
```
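If your frame comes from an OpenCV capture rather than PIL (an assumption on my part, based on the variable name frame), you can drop the alpha channel with a colour conversion instead; cv2.cvtColor with the COLOR_BGRA2RGB code is standard OpenCV, and the rest mirrors your preprocessing:

```python
import cv2
import numpy as np

# Assumption: frame is a BGRA uint8 array, e.g. from a screenshot grab.
frame = cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)   # 4 channels -> 3
frame = frame.transpose((2, 0, 1))                # HWC -> CHW
frame = frame[np.newaxis, :].astype(np.float32) / 255.0
results = onnx_session.run(['output0'], {'images': frame})
```

Either way, the array fed to the session should end up as float32 with shape (1, 3, 320, 320).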
Upvotes: 1