re4per
re4per

Reputation: 11

YOLOv9e-seg does not segment the entire image

I trained YOLOv9e-seg on images of size 2000x2000, but when I use the trained model to predict segmentation on an image of the same resolution, part of the image is missing from the predicted segmentation. When I predict with the same model on a 1024x1024 image, the entire image is segmented, as you can see in the image below. The cerebral cortex (the light-green part, i.e. the large region at the top of the brain) is clearly segmented in image 2 but only partly segmented in image 1.

[Image not available: side-by-side segmentation results for the 2000x2000 image (image 1) and the 1024x1024 image (image 2)]

Below is the training code:

import torch
from ultralytics import YOLO
import os

from ultralytics import YOLO

# Start from the pretrained YOLOv9e segmentation checkpoint.
model = YOLO("yolov9e-seg.pt")

# Fine-tune on the brain dataset using GPUs 0, 1, 4 and 5.
# NOTE(review): Ultralytics rounds `imgsz` up to a multiple of the model stride
# (32), so imgsz=2000 is most likely trained at 2016 px; check the training log
# for the corresponding warning. The same rounding applies at prediction time.
model.train(
    data='training_data/brain_data.yaml',
    epochs=3000,
    imgsz=2000,
    batch=8,
    project='brain_segmentation',
    name='overnight_run',
    device=[0, 1, 4, 5],
    close_mosaic=15,  # disable mosaic augmentation for the final 15 epochs
    save_period=500,  # also keep periodic checkpoints (e.g. weights/epoch500.pt)
)

The prediction and plotting code is:

# (class name, hex colour) pairs. The order must match the class order in the
# dataset YAML / model.names, because the prediction code compares the two lists.
_LABEL_COLOR_PAIRS = (
    ("Thalamus", "#6ff151"),
    ("Caudate nucleus", "#65ce0c"),
    ("Putamen", "#d6cb4d"),
    ("Globus pallidus", "#e5feba"),
    ("Nucleus accumbens", "#8e6f03"),
    ("Internal capsule", "#8e4527"),
    ("Substantia innominata", "#4f9b02"),
    ("Fornix", "#ac561e"),
    ("Anterior commissure", "#b7cc25"),
    ("Ganglionic eminence", "#876f47"),
    ("Hypothalamus", "#3fe936"),
    ("Amygdala", "#74af15"),
    ("Hippocampus", "#bd885c"),
    ("Choroid plexus", "#b5e60a"),
    ("Lateral ventricle", "#88b151"),
    ("Olfactory tubercle", "#ecad5e"),
    ("Pretectum", "#707166"),
    ("Inferior colliculus", "#a1830f"),
    ("Superior colliculus", "#ff9b3d"),
    ("Tegmentum", "#eaeea1"),
    ("Pons", "#cc7e39"),
    ("Medulla", "#fcae1b"),
    ("Cerebellum", "#4e4137"),
    ("Corpus callosum", "#de998d"),
    ("Cerebral cortex", "#50fa0d"),
)

# Mapping class name -> hex colour; dict() preserves the insertion order above.
label_colors = dict(_LABEL_COLOR_PAIRS)

import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO
from matplotlib.patches import Polygon as mplPolygon

def hex_to_rgb(hex_color):
    """Convert a hex colour code to an ``(R, G, B)`` tuple on the 0-255 scale.

    Accepts ``"#rrggbb"`` / ``"rrggbb"`` and the CSS shorthand ``"#rgb"`` / ``"rgb"``
    (each digit is doubled, so ``"#0a0"`` means ``"#00aa00"``).

    Raises:
        ValueError: if the string (after stripping ``#``) is not 3 or 6 hex digits.
    """
    digits = hex_color.lstrip('#')
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    # Without this check a 5- or 7-digit string would be silently mis-parsed.
    if len(digits) != 6:
        raise ValueError(f"Expected a 3- or 6-digit hex colour, got {hex_color!r}")
    # int(..., 16) itself raises ValueError for non-hex characters.
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

def _show_image_only(image_rgb, title):
    """Display ``image_rgb`` with ``title`` and no overlays (used when nothing is detected)."""
    plt.figure(figsize=(16, 10))
    plt.imshow(image_rgb)
    plt.title(title, fontsize=16)
    plt.axis('off')
    plt.show()


def _draw_mask_regions(ax, mask_np, color_normalized, min_area=50):
    """Outline and fill each external contour of one instance mask on ``ax``.

    ``mask_np`` is expected to be a 2-D 0/1 mask at image resolution; it is scaled to
    0/255 uint8 for ``cv2.findContours``. Contours with area below ``min_area`` pixels
    are skipped as noise. ``color_normalized`` is an RGB tuple in 0-1.

    Returns:
        The number of contours actually drawn.
    """
    mask_uint8 = (mask_np * 255).astype(np.uint8)
    contours, _ = cv2.findContours(
        mask_uint8,
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_TC89_KCOS,
    )
    facecolor = color_normalized + (0.3,)  # same hue, 30% opaque fill
    drawn = 0
    for contour in contours:
        if cv2.contourArea(contour) < min_area:
            continue
        polygon = contour.squeeze(1).astype(np.float32)  # (N, 1, 2) -> (N, 2) vertices
        ax.add_patch(mplPolygon(
            polygon,
            closed=True,
            edgecolor=color_normalized,
            facecolor=facecolor,
            linewidth=1.5,
            linestyle='-',
        ))
        drawn += 1
    return drawn


def yolo_predict_and_visualize(image_path, model_path='best.pt', title_plot="Prediction", imgsz=2000,
                               conf=0.1, iou=0.3, device='cpu'):
    """Run YOLO segmentation on one image and plot the predicted regions.

    Args:
        image_path: Path of the image to load with ``cv2.imread``.
        model_path: Path of the YOLO segmentation weights.
        title_plot: First line of the plot title.
        imgsz: Square inference size passed to the model as ``[imgsz, imgsz]``.
        conf: Confidence threshold for detections.
        iou: IoU threshold used for NMS.
        device: Inference device passed to the model (defaults to ``'cpu'`` as before).

    Returns:
        The list of Ultralytics result objects returned by the model call. It is
        returned on every path, including when nothing is detected.

    Raises:
        AssertionError: if the model's class names differ from ``label_colors``.
        ValueError: if the image cannot be loaded, or if the model returns boxes
            without masks (i.e. it is not a segmentation model).

    NOTE(review): Ultralytics crops each instance mask to its predicted bounding box,
    so a region that is only partly segmented usually means the predicted box is
    too small. It is not necessarily a plotting issue.
    """
    model = YOLO(model_path)
    print("Model device:", model.device)
    print("Model class names:", model.names)
    # Explicit raise instead of `assert` so the check is not stripped under `python -O`.
    if list(model.names.values()) != list(label_colors.keys()):
        raise AssertionError("Class name mismatch!")

    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not load image at {image_path}")
    print(f"Original image shape: {image.shape}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # matplotlib expects RGB

    results = model(
        image,
        imgsz=[imgsz, imgsz],
        conf=conf,
        iou=iou,
        augment=False,
        verbose=True,
        retina_masks=True,  # masks at the original image resolution, so they align with image_rgb
        max_det=300,
        device=device,
    )
    result = results[0]

    boxes = result.boxes
    class_ids = boxes.cls.cpu().numpy().astype(int) if boxes else np.array([], dtype=int)
    scores = boxes.conf.cpu().numpy() if boxes else np.array([])
    masks = result.masks

    print(f"\nDetected {len(class_ids)} instances")
    if len(class_ids) == 0:
        print("No detections found!")
        _show_image_only(image_rgb, "No Detections Found")
        return results
    if masks is None:
        raise ValueError("Model returned boxes but no masks; is this a segmentation model?")

    height, width = image.shape[:2]
    # Keep the image aspect ratio; clamp so very wide images don't get a 0-height figure.
    fig, ax = plt.subplots(figsize=(20, max(1, int(20 * height / width))))
    ax.imshow(image_rgb)

    legend_handles = []
    seen_class_ids = set()
    class_region_counts = {}
    total_regions = 0

    for i, (mask, class_id, score) in enumerate(zip(masks, class_ids, scores)):
        class_name = model.names[class_id]
        color_normalized = tuple(c / 255 for c in hex_to_rgb(label_colors[class_name]))

        print(f"\n{'='*40}")
        print(f"Instance {i+1}: {class_name} (Confidence: {score:.2f})")

        regions = _draw_mask_regions(ax, mask.data[0].cpu().numpy(), color_normalized)
        # Record the class even if all of its contours were filtered out.
        class_region_counts[class_name] = class_region_counts.get(class_name, 0) + regions
        total_regions += regions

        if class_id not in seen_class_ids:
            legend_handles.append(plt.Line2D(
                [0], [0],
                marker='s',
                markersize=12,
                linewidth=0,
                color=color_normalized,
                markerfacecolor=color_normalized,
                label=f"{class_name}",
            ))
            seen_class_ids.add(class_id)

    print("\nDetailed Statistics:")
    for class_name, count in sorted(class_region_counts.items()):
        print(f"{class_name}: {count} regions")
    print(f"\nTotal regions detected: {total_regions}")

    ax.set_title(f"{title_plot}\nBrain Segmentation Results\n{len(class_ids)} Instances, {total_regions} Regions",
                 fontsize=16, pad=20)
    ax.axis('off')
    ax.legend(
        handles=legend_handles,
        bbox_to_anchor=(1.05, 1),  # place the legend outside the image on the right
        loc='upper left',
        borderaxespad=0.5,
        fontsize=10,
        title="Detected Regions",
        title_fontsize=12,
    )
    fig.tight_layout()
    plt.show()

    return results

# Usage: run the same pipeline on a 2000x2000 image and a 1024x1024 image.
# NOTE(review): the two runs load different checkpoints (best.pt vs. epoch500.pt),
# so differences between the two plots are not attributable to image size alone.
import traceback

try:
    results = yolo_predict_and_visualize(
        image_path="image_2000.jpg",
        model_path="brain_segmentation/overnight_run/weights/best.pt",
        title_plot="Prediction 2000",
    )
except Exception as e:
    # Top-level boundary: report the failure (with traceback) and continue to the next run.
    print(f"Error: {str(e)}")
    traceback.print_exc()

try:
    results_1024 = yolo_predict_and_visualize(
        image_path="image_1024.jpg",
        model_path="brain_segmentation/overnight_run/weights/epoch500.pt",
        title_plot="Prediction 1024",
        imgsz=1024,
    )
except Exception as e:
    print(f"Error: {str(e)}")
    traceback.print_exc()

My training data is organized as follows:

training_data
|--->images
|   |--->train
|   |--->val
|--->labels
|   |--->train
|   |--->val
|--->brain_data.yaml

brain_data.yaml

path: /storage/user/work_my/new_yolo_test/training_data  # dataset root
# List image directories only. Ultralytics finds each image's label file by
# replacing the `images` path component with `labels`
# (images/train/x.jpg -> labels/train/x.txt), so label directories must not
# be listed here.
train: images/train  # training images
val: images/val  # validation images

nc: 25
names: ['Thalamus', 'Caudate nucleus', 'Putamen', 'Globus pallidus', 'Nucleus accumbens', 'Internal capsule', 'Substantia innominata', 'Fornix', 'Anterior commissure', 'Ganglionic eminence', 'Hypothalamus', 'Amygdala', 'Hippocampus', 'Choroid plexus', 'Lateral ventricle', 'Olfactory tubercle', 'Pretectum', 'Inferior colliculus', 'Superior colliculus', 'Tegmentum', 'Pons', 'Medulla', 'Cerebellum', 'Corpus callosum', 'Cerebral cortex']

Note:

Upvotes: 0

Views: 19

Answers (0)

Related Questions