Object Counting Across Frames in YOLOv5 Video Detection

I'm working on an object detection project using YOLOv5 in Python, where I need to detect and count bicycles and motorcycles in a video as they cross a vertical line in the middle of the frame. However, I’m facing an issue where the objects are being counted in every frame, leading to inflated counts.

In my code, I track each object using a unique object_id that’s generated based on the object’s class and position. Despite this, the detection seems to reset in each frame, and the objects get counted multiple times.

Here’s a snippet of the code (the full version is below):


import torch
import cv2
import numpy as np

# Load the YOLOv5 model (we are using the small version, yolov5s)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')

# Load the video
video_path = 'C:/Users/komyj/Downloads/yolov5-master/yolov5-master/data/images/test_2.mp4'
output_path = 'C:/Users/komyj/Downloads/yolov5-master/yolov5-master/data/images/output.mp4'
cap = cv2.VideoCapture(video_path)

if not cap.isOpened():
    print("Error opening the video.")
    exit()

# Get video dimensions
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
line_position = frame_width // 2  # Vertical line in the middle of the video

# Initialize counters
count_bikes = 0
count_motorcycles = 0

# Dictionary to store tracked objects with their ID and position history
tracked_objects = {}

# Function to check if an object has crossed the line
def is_object_passing(line_position, current_x, previous_x):
    return previous_x < line_position and current_x >= line_position

# Setup VideoWriter to save the output video
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

if not out.isOpened():
    print("Error opening VideoWriter.")
    cap.release()
    exit()

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    # Convert the image to RGB (OpenCV uses BGR)
    img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    
    # Object detection
    results = model(img_rgb)
    
    # Render the detection results
    img_rendered = results.render()[0]
    
    # Convert back to BGR for OpenCV
    img_bgr = cv2.cvtColor(img_rendered, cv2.COLOR_RGB2BGR)
    
    # Display the vertical line
    cv2.line(img_bgr, (line_position, 0), (line_position, frame_height), (255, 0, 0), 2)

    # Process detected objects
    for detection in results.pred[0]:
        x1, y1, x2, y2, conf, class_id = detection
        center_x = int((x1 + x2) / 2)
        center_y = int((y1 + y2) / 2)

        # Generate a unique ID for each object based on its class and position
        object_id = f"{int(class_id)}-{int(x1)}-{int(y1)}"

        # If the object is already being tracked, update its position
        if object_id in tracked_objects:
            prev_center_x = tracked_objects[object_id]['center_x']
            tracked_objects[object_id]['center_x'] = center_x
        else:
            # If the object is not being tracked, add it with its ID and initial position
            tracked_objects[object_id] = {'class_id': class_id, 'center_x': center_x, 'counted': False}

        # Check if the object has already been counted and if it has crossed the line
        if not tracked_objects[object_id]['counted'] and is_object_passing(line_position, center_x, prev_center_x):
            if class_id == 1:  # ID for a bicycle
                count_bikes += 1
            elif class_id == 3:  # ID for a motorcycle
                count_motorcycles += 1
            tracked_objects[object_id]['counted'] = True  # Mark the object as counted

        # Draw a rectangle around the detected object
        if class_id == 1:  # ID for a bicycle
            cv2.rectangle(img_bgr, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            cv2.putText(img_bgr, "Bike", (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        
        elif class_id == 3:  # ID for a motorcycle
            cv2.rectangle(img_bgr, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
            cv2.putText(img_bgr, "Motorcycle", (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)

    # Display the bike and motorcycle counts
    cv2.putText(img_bgr, f'Bikes Count: {count_bikes}', (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    cv2.putText(img_bgr, f'Motorcycles Count: {count_motorcycles}', (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
    
    # Save the frame to the new video
    out.write(img_bgr)

cap.release()
out.release()
cv2.destroyAllWindows()

# Print the final results
print(f'Final Bike Count: {count_bikes}')
print(f'Final Motorcycle Count: {count_motorcycles}')

Here’s a summary of what I’ve tried:

Upvotes: 0

Views: 102

Answers (0)

Related Questions