Reputation: 29
I'm working on an object detection project using YOLOv5 in Python, where I need to detect and count bicycles and motorcycles in a video as they cross a vertical line in the middle of the frame. However, I’m facing an issue where the objects are being counted in every frame, leading to inflated counts.
In my code, I track each object using a unique object_id
that’s generated based on the object’s class and position. Despite this, the detection seems to reset in each frame, and the objects get counted multiple times.
Here’s a snippet of the code (the full version is below):
import torch
import cv2
import numpy as np
# Load the YOLOv5 model (we are using the small version, yolov5s)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
# Load the video
video_path = 'C:/Users/komyj/Downloads/yolov5-master/yolov5-master/data/images/test_2.mp4'
output_path = 'C:/Users/komyj/Downloads/yolov5-master/yolov5-master/data/images/output.mp4'
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
print("Error opening the video.")
exit()
# Get video dimensions
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
line_position = frame_width // 2 # Vertical line in the middle of the video
# Initialize counters
count_bikes = 0
count_motorcycles = 0
# Dictionary to store tracked objects with their ID and position history
tracked_objects = {}
# Function to check if an object has crossed the line
def is_object_passing(line_position, current_x, previous_x):
return previous_x < line_position and current_x >= line_position
# Setup VideoWriter to save the output video
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
if not out.isOpened():
print("Error opening VideoWriter.")
cap.release()
exit()
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Convert the image to RGB (OpenCV uses BGR)
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Object detection
results = model(img_rgb)
# Render the detection results
img_rendered = results.render()[0]
# Convert back to BGR for OpenCV
img_bgr = cv2.cvtColor(img_rendered, cv2.COLOR_RGB2BGR)
# Display the vertical line
cv2.line(img_bgr, (line_position, 0), (line_position, frame_height), (255, 0, 0), 2)
# Process detected objects
for detection in results.pred[0]:
x1, y1, x2, y2, conf, class_id = detection
center_x = int((x1 + x2) / 2)
center_y = int((y1 + y2) / 2)
# Generate a unique ID for each object based on its class and position
object_id = f"{int(class_id)}-{int(x1)}-{int(y1)}"
# If the object is already being tracked, update its position
if object_id in tracked_objects:
prev_center_x = tracked_objects[object_id]['center_x']
tracked_objects[object_id]['center_x'] = center_x
else:
# If the object is not being tracked, add it with its ID and initial position
tracked_objects[object_id] = {'class_id': class_id, 'center_x': center_x, 'counted': False}
# Check if the object has already been counted and if it has crossed the line
if not tracked_objects[object_id]['counted'] and is_object_passing(line_position, center_x, prev_center_x):
if class_id == 1: # ID for a bicycle
count_bikes += 1
elif class_id == 3: # ID for a motorcycle
count_motorcycles += 1
tracked_objects[object_id]['counted'] = True # Mark the object as counted
# Draw a rectangle around the detected object
if class_id == 1: # ID for a bicycle
cv2.rectangle(img_bgr, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
cv2.putText(img_bgr, "Bike", (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
elif class_id == 3: # ID for a motorcycle
cv2.rectangle(img_bgr, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
cv2.putText(img_bgr, "Motorcycle", (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
# Display the bike and motorcycle counts
cv2.putText(img_bgr, f'Bikes Count: {count_bikes}', (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
cv2.putText(img_bgr, f'Motorcycles Count: {count_motorcycles}', (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
# Save the frame to the new video
out.write(img_bgr)
cap.release()
out.release()
cv2.destroyAllWindows()
# Print the final results
print(f'Final Bike Count: {count_bikes}')
print(f'Final Motorcycle Count: {count_motorcycles}')
Here’s a summary of what I’ve tried:
I track the objects using a dictionary (tracked_objects
) to store their object_id
, position (center_x
), and a counted
flag.
I check if an object has crossed the line using its current and previous positions.
I only increment the count when an object crosses the line from left to right and hasn’t been counted yet.
Upvotes: 0
Views: 102