Reputation: 175
I am trying to run inference on my custom YOLOv5 model. The official documentation uses the default detect.py script for inference. I have written my own Python script, but I cannot access the predicted class and the bounding box coordinates from the model's output. Here is my code:
import torch
model = torch.hub.load('ultralytics/yolov5', 'custom', path_or_model='best.pt')
predictions = model("my_image.png")
print(predictions)
Upvotes: 7
Views: 25986
Reputation: 106
YOLOv5 🚀 PyTorch Hub models allow for simple model loading and inference in a pure Python environment without using detect.py.
This example loads a pretrained YOLOv5s model from PyTorch Hub as model and passes an image for inference. 'yolov5s' is the YOLOv5 'small' model. For details on all available models please see the README. Custom models can also be loaded, including custom trained PyTorch models and their exported variants, i.e. ONNX, TensorRT, TensorFlow, OpenVINO YOLOv5 models.
import torch
# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # or yolov5m, yolov5l, yolov5x, etc.
# model = torch.hub.load('ultralytics/yolov5', 'custom', 'path/to/best.pt') # custom trained model
# Images
im = 'https://ultralytics.com/images/zidane.jpg' # or file, Path, URL, PIL, OpenCV, numpy, list
# Inference
results = model(im)
# Results
results.print() # or .show(), .save(), .crop(), .pandas(), etc.
results.xyxy[0] # im predictions (tensor)
results.pandas().xyxy[0] # im predictions (pandas)
# xmin ymin xmax ymax confidence class name
# 0 749.50 43.50 1148.0 704.5 0.874023 0 person
# 2 114.75 195.75 1095.0 708.0 0.624512 0 person
# 3 986.00 304.00 1028.0 420.0 0.286865 27 tie
See YOLOv5 PyTorch Hub Tutorial for details.
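To get at the predicted class and box coordinates directly, each row of results.xyxy[0] can be unpacked as [xmin, ymin, xmax, ymax, confidence, class]. The following is a minimal sketch, not part of the official tutorial: parse_detections is a hypothetical helper, and the rows and names values are stand-ins copied from the sample output above so that the snippet runs without downloading a model. With a real model you would presumably pass results.xyxy[0].tolist() and results.names instead.

```python
def parse_detections(rows, names):
    """Turn rows of [xmin, ymin, xmax, ymax, confidence, class] into dicts.

    `rows` is a list of 6-element rows; `names` maps class id -> class name.
    Both the helper and its return format are illustrative assumptions.
    """
    detections = []
    for xmin, ymin, xmax, ymax, conf, cls in rows:
        class_id = int(cls)  # the class column is stored as a float
        detections.append({
            "box": (xmin, ymin, xmax, ymax),
            "confidence": conf,
            "class_id": class_id,
            "name": names[class_id],
        })
    return detections


# Stand-in data; with a real model (assumption):
#   rows = results.xyxy[0].tolist()
#   names = results.names
rows = [
    [749.50, 43.50, 1148.0, 704.5, 0.874023, 0.0],
    [986.00, 304.00, 1028.0, 420.0, 0.286865, 27.0],
]
names = {0: "person", 27: "tie"}

for det in parse_detections(rows, names):
    print(det["name"], det["confidence"], det["box"])
```

Working from plain lists keeps the helper independent of whether the tensor lives on the CPU or the GPU.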
Upvotes: 8
Reputation: 241
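results = model(input_images)
# Each row of results.xyxyn[0] is [xmin, ymin, xmax, ymax, confidence, class], with coordinates normalized to [0, 1].
# .cpu() is a no-op on CPU tensors and avoids an error when the model runs on a GPU.
labels, cord_thres = results.xyxyn[0][:, -1].cpu().numpy(), results.xyxyn[0][:, :-1].cpu().numpy()
This gives you the class labels, the normalized coordinates, and the confidence scores for each detected object, which you can use to plot bounding boxes. You can check out this repo for more detailed code.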
https://github.com/akash-agni/Real-Time-Object-Detection
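Because xyxyn returns coordinates normalized to the image size, they need to be scaled back to pixels before drawing. A rough sketch of that step follows, assuming labels and cord_thres have the layout produced above (one class id per detection, and rows of [xmin, ymin, xmax, ymax, confidence]). The helper name to_pixel_boxes and the min_conf cutoff of 0.25 are illustrative choices, not taken from the linked repo.

```python
def to_pixel_boxes(labels, cord_thres, width, height, min_conf=0.25):
    """Scale normalized boxes to pixel coordinates and drop weak detections.

    Returns tuples of (class_id, x1, y1, x2, y2, confidence).
    Hypothetical helper; the threshold value is an assumption.
    """
    boxes = []
    for label, (x1, y1, x2, y2, conf) in zip(labels, cord_thres):
        if conf < min_conf:
            continue  # skip low-confidence detections
        boxes.append((
            int(label),
            int(x1 * width), int(y1 * height),
            int(x2 * width), int(y2 * height),
            conf,
        ))
    return boxes


# Stand-in values shaped like the arrays from results.xyxyn[0]
labels = [0.0, 27.0]
cord_thres = [
    [0.5, 0.25, 1.0, 0.75, 0.9],
    [0.1, 0.1, 0.2, 0.2, 0.1],
]
print(to_pixel_boxes(labels, cord_thres, width=640, height=480))
```

The resulting integer corners can be passed to whatever drawing routine you use, for example a rectangle function from an image library.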
Upvotes: 12