Hiba Lashari

Reputation: 1

Structured Pruning of YOLOv8

I need to run an object detector on a Raspberry Pi 4B for real-time object detection, and I have decided to use YOLOv8n. Since I don't have a hardware accelerator, the only option I see is to prune and quantize the model. I only need to detect 3 categories: humans, animals, and vehicles. So I customized the COCO dataset, discarding all classes except the 11 that belong to these categories, and then fine-tuned the model for 50 epochs.
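The filtering script is not shown. Below is a minimal sketch of this kind of class filtering. It assumes YOLO-format label files (one `class cx cy w h` line per box) and the standard 80-class COCO index order used by Ultralytics' `coco.yaml`. `KEEP`, `REMAP`, and `filter_label_lines` are hypothetical names:

```python
# Hypothetical helper: assumes YOLO-format label lines ("class cx cy w h")
# and the standard 80-class COCO index order (as in Ultralytics' coco.yaml).
KEEP = [0, 1, 2, 3, 5, 7, 15, 16, 17, 18, 19]  # person, bicycle, car, motorcycle, bus, truck, cat, dog, horse, sheep, cow
REMAP = {old: new for new, old in enumerate(KEEP)}  # old COCO id -> new id 0..10


def filter_label_lines(lines):
    """Keep only boxes of the target classes and renumber their class IDs."""
    out = []
    for line in lines:
        parts = line.split()
        if not parts:  # skip blank lines
            continue
        cls = int(parts[0])
        if cls in REMAP:
            out.append(" ".join([str(REMAP[cls])] + parts[1:]))
    return out


example = [
    "0 0.5 0.5 0.2 0.4",     # person -> stays 0
    "7 0.3 0.6 0.1 0.1",     # truck -> becomes 5
    "56 0.1 0.1 0.05 0.05",  # class 56 is not kept -> dropped
]
print(filter_label_lines(example))
```

The 11 kept IDs are renumbered 0 to 10 in COCO order, which matches the class order in the validation output below. The `names` list in the dataset YAML has to follow that same order.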

My next task is to prune the model. I have learned that structured pruning is the usual way to prune models for edge devices. I implemented it, pruned my model, and then fine-tuned it for 5 epochs. Now I am not sure whether I have done it correctly. To check, I looked at the number of parameters and found that it is the same before and after pruning. The code I am using is below:
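`torch.nn.utils.prune` zeroes whole filters in place but keeps every tensor at its original shape. That is consistent with the unchanged parameter count. The sketch below uses a hypothetical toy `Conv2d -> ReLU -> Conv2d` pair, not YOLOv8. It applies the same `ln_structured` and `remove` calls, checks the parameter count, and then slices the zeroed channels out to show what an actual size reduction involves. It assumes the first conv has no bias and no BatchNorm follows it, so a zeroed filter outputs exactly 0 and removing it does not change the result:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def n_params(*modules: nn.Module) -> int:
    """Total number of parameter values, zero or not."""
    return sum(p.numel() for m in modules for p in m.parameters())


torch.manual_seed(0)
# Hypothetical toy pair standing in for two consecutive convs.
# conv1 has no bias (and no BatchNorm follows), so a zeroed filter outputs exactly 0.
conv1 = nn.Conv2d(3, 16, 3, padding=1, bias=False)
conv2 = nn.Conv2d(16, 32, 3, padding=1)
before = n_params(conv1, conv2)

# Same calls as prune_model: zero 25% of conv1's output filters by L2 norm
prune.ln_structured(conv1, name="weight", amount=0.25, n=2, dim=0)
prune.remove(conv1, "weight")
after_mask = n_params(conv1, conv2)  # unchanged: the zeros are still stored

# Which output channels survived? (a pruned filter has L2 norm exactly 0)
norms = torch.linalg.vector_norm(conv1.weight.detach().flatten(1), dim=1)
keep = torch.nonzero(norms > 0).squeeze(1)

# Physically remove them: a smaller conv1, and conv2 loses the matching input channels
small1 = nn.Conv2d(3, keep.numel(), 3, padding=1, bias=False)
small2 = nn.Conv2d(keep.numel(), 32, 3, padding=1)
with torch.no_grad():
    small1.weight.copy_(conv1.weight[keep])     # drop pruned output channels
    small2.weight.copy_(conv2.weight[:, keep])  # drop matching input channels
    small2.bias.copy_(conv2.bias)
after_slice = n_params(small1, small2)

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    same = torch.allclose(conv2(torch.relu(conv1(x))), small2(torch.relu(small1(x))), atol=1e-5)
print(f"params: {before} -> {after_mask} (masked) -> {after_slice} (sliced), same output: {same}")
```

On this toy pair, masking leaves 5,072 parameters. Slicing brings the count to 3,812 with the same output. In YOLOv8 the same slicing has to respect every layer that consumes a pruned channel: the concatenations and residual adds inside `C2f`, the neck, and the `Detect` head. The BatchNorm after each conv would need slicing too. This sketch does not handle any of that.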

Is my strategy for optimizing the model correct? Am I doing the pruning correctly? Once pruning is done, is there a way to quantize the model to INT8? Should I export the model to ONNX format? What is the difference between ONNX and ONNX Runtime? I am also confused about the parameter count: the model summary shows 3,007,793 parameters, whereas when I print the parameters it shows 3,012,993. Why is this?
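On the INT8 question, here is a minimal sketch of the arithmetic behind symmetric per-tensor INT8 quantization. It runs on a random tensor shaped like a 16-to-32-channel 3x3 conv weight. It is not what Ultralytics export, ONNX Runtime, or PyTorch's quantization tools do internally; real tools can also use zero points, per-channel scales, and calibration data. It only shows the scale/round/clamp mapping and why INT8 weights take a quarter of the float32 storage:

```python
import torch


def quantize_int8_symmetric(w: torch.Tensor) -> tuple[torch.Tensor, float]:
    """Per-tensor symmetric INT8: map [-max|w|, max|w|] onto the integers [-127, 127]."""
    scale = w.abs().max().item() / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale


def dequantize(q: torch.Tensor, scale: float) -> torch.Tensor:
    """Map the int8 codes back to approximate float32 values."""
    return q.to(torch.float32) * scale


torch.manual_seed(0)
# Random stand-in for a conv weight of shape (out=32, in=16, 3, 3)
w = torch.randn(32, 16, 3, 3)
q, scale = quantize_int8_symmetric(w)
w_hat = dequantize(q, scale)

max_err = (w - w_hat).abs().max().item()  # rounding error, at most about scale / 2
print(f"scale={scale:.5f}, max abs error={max_err:.5f}")
print(f"bytes: float32={w.numel() * w.element_size()}, int8={q.numel() * q.element_size()}")
```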

import torch
import torch.nn.utils.prune as prune
from ultralytics import YOLO

# Function to prune Conv2D layers (structured pruning for better performance)
def prune_model(model, amount=0.1):
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.ln_structured(module, name='weight', amount=amount, n=2, dim=0)  # Channel pruning
            prune.remove(module, 'weight')  # Make pruning permanent
    return model

# Load YOLO model
model = YOLO('best.pt')

# Validate the original model
results_str = model.val(data="custom_coco.yaml")
print(f"Original mAP50-95: {results_str.box.map}")

# Access the PyTorch model
torch_model_structured = model.model

# Apply pruning
print("Pruning model...")
pruned_torch_model_structured = prune_model(torch_model_structured, amount=0.1)
print("Model pruned.")

# Save pruned model
torch.save(pruned_torch_model_structured.state_dict(), 'yolov8n_Structured_pruned_weights.pth')
print("Pruned model weights saved as 'yolov8n_Structured_pruned_weights.pth'.")

# Reload pruned weights into a YOLO model
pruned_model = YOLO('best.pt')  # Load the original YOLO model
pruned_model.model.load_state_dict(torch.load('yolov8n_Structured_pruned_weights.pth'), strict=False)

# Validate pruned model
results_str = pruned_model.val(data="custom_coco.yaml")
print(f"Pruned mAP50-95: {results_str.box.map}")

# Fine-tune pruned model
print("Fine-tuning pruned model...")
results_str = pruned_model.train(
    data='custom_coco.yaml',
    epochs=5,
    imgsz=640,
    batch=8,
    lr0=0.001  # Lower learning rate for stability
)
pruned_model.save('yolov8n_structured_pruned_finetuned.pt')
print("Pruned and fine-tuned model saved as 'yolov8n_structured_pruned_finetuned.pt'.")

# Validate fine-tuned model
fine_tuned_model = YOLO('yolov8n_structured_pruned_finetuned.pt')
results_str = fine_tuned_model.val(data="custom_coco.yaml")
print(f"Fine-tuned mAP50-95: {results_str.box.map}")

Output of this code:

Ultralytics 8.3.31  Python-3.10.14 torch-2.4.1+cu118 CUDA:0 (NVIDIA GeForce GTX 1060 6GB, 6144MiB)
Model summary (fused): 168 layers, 3,007,793 parameters, 0 gradients, 8.1 GFLOPs
val: Scanning J:\quantization\datasets\custom_data\labels\val.cache... 3321 images, 0 backgrounds, 0 corrupt: 100%|██████████| 3321/3321 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 208/208 [00:57<00:00,  3.61it/s]
                   all       3321      15491       0.72      0.599      0.669      0.481
                person       2693      10777      0.798      0.663      0.761      0.532
               bicycle        149        314        0.7      0.379      0.438      0.257
                   car        535       1918      0.698      0.521      0.585      0.378
            motorcycle        159        367      0.709      0.572      0.657      0.418
                   bus        189        283      0.815      0.653      0.743      0.615
                 truck        250        414      0.573      0.379      0.451      0.297
                   cat        184        202      0.782      0.835      0.867      0.672
                   dog        177        218      0.717      0.683      0.748      0.602
                 horse        128        272      0.742      0.624      0.741      0.558
                 sheep         65        354      0.624      0.653      0.666       0.46
                   cow         87        372      0.756      0.625      0.702      0.496
Speed: 0.4ms preprocess, 4.1ms inference, 0.0ms loss, 2.2ms postprocess per image
Results saved to runs\detect\val22
Original mAP50-95: 0.4805636008285908
Pruning model...
Model pruned.
Pruned model weights saved as 'yolov8n_Structured_pruned_weights.pth'.
Ultralytics 8.3.31  Python-3.10.14 torch-2.4.1+cu118 CUDA:0 (NVIDIA GeForce GTX 1060 6GB, 6144MiB)
Model summary (fused): 168 layers, 3,007,793 parameters, 0 gradients, 8.1 GFLOPs
val: Scanning J:\quantization\datasets\custom_data\labels\val.cache... 3321 images, 0 backgrounds, 0 corrupt: 100%|██████████| 3321/3321 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 208/208 [01:07<00:00,  3.08it/s]
                   all       3321      15491      0.455   0.000775   4.48e-05   9.65e-06
                person       2693      10777    0.00331   0.000186    3.5e-05   1.46e-05
               bicycle        149        314          1          0          0          0
                   car        535       1918   0.000977    0.00834   0.000458   9.15e-05
            motorcycle        159        367          1          0          0          0
                   bus        189        283          0          0          0          0
                 truck        250        414          1          0          0          0
                   cat        184        202          1          0          0          0
                   dog        177        218          0          0          0          0
                 horse        128        272          0          0          0          0
                 sheep         65        354          0          0          0          0
                   cow         87        372          1          0          0          0
Speed: 0.4ms preprocess, 4.0ms inference, 0.0ms loss, 2.6ms postprocess per image
Results saved to runs\detect\val23
Pruned mAP50-95: 9.649495810499685e-06
Fine-tuning pruned model...
New https://pypi.org/project/ultralytics/8.3.34 available  Update with 'pip install -U ultralytics'
Ultralytics 8.3.31  Python-3.10.14 torch-2.4.1+cu118 CUDA:0 (NVIDIA GeForce GTX 1060 6GB, 6144MiB)
engine\trainer: task=detect, mode=train, model=best.pt, data=custom_coco.yaml, epochs=5, time=None, patience=100, batch=8, imgsz=640, save=True, save_period=-1, cache=False, device=None, workers=8, project=None, name=train9, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=True, show_conf=True, show_boxes=True, line_width=None, format=torchscript, keras=False, optimize=False, int8=False, dynamic=False, simplify=True, opset=None, workspace=4, nms=False, lr0=0.001, lrf=0.01, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=7.5, cls=0.5, dfl=1.5, pose=12.0, kobj=1.0, label_smoothing=0.0, nbs=64, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, bgr=0.0, mosaic=1.0, mixup=0.0, copy_paste=0.0, copy_paste_mode=flip, auto_augment=randaugment, erasing=0.4, crop_fraction=1.0, cfg=None, tracker=botsort.yaml, save_dir=runs\detect\train9

                   from  n    params  module                                       arguments
  0                  -1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  2                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  3                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  4                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  5                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
  6                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
  7                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
  8                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
  9                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 10                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 11             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 12                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 14             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 15                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 16                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 17            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 18                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 19                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 20             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 21                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]
 22        [15, 18, 21]  1    753457  ultralytics.nn.modules.head.Detect           [11, [64, 128, 256]]
Model summary: 225 layers, 3,012,993 parameters, 3,012,977 gradients, 8.2 GFLOPs

Transferred 70/355 items from pretrained weights
TensorBoard: Start with 'tensorboard --logdir runs\detect\train9', view at http://localhost:6006/
Freezing layer 'model.22.dfl.conv.weight'
AMP: running Automatic Mixed Precision (AMP) checks...
AMP: checks passed
train: Scanning J:\quantization\datasets\custom_data\labels\train.cache... 78140 images, 0 backgrounds, 0 corrupt: 100%|██████████| 78140/78140 [00:00<?, ?it/s]
val: Scanning J:\quantization\datasets\custom_data\labels\val.cache... 3321 images, 0 backgrounds, 0 corrupt: 100%|██████████| 3321/3321 [00:00<?, ?it/s]
Plotting labels to runs\detect\train9\labels.jpg...
optimizer: 'optimizer=auto' found, ignoring 'lr0=0.001' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically...
optimizer: AdamW(lr=0.000667, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
TensorBoard: model graph visualization added
Image sizes 640 train, 640 val
Using 8 dataloader workers
Logging results to runs\detect\train9
Starting training for 5 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
        1/5      1.76G      1.382      1.496      1.337         12        640: 100%|██████████| 9768/9768 [41:16<00:00,  3.94it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 208/208 [00:59<00:00,  3.50it/s]
                   all       3321      15491      0.643      0.494      0.546      0.363

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
        2/5      1.59G      1.178      1.189      1.202         24        640: 100%|██████████| 9768/9768 [37:35<00:00,  4.33it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 208/208 [00:51<00:00,  4.04it/s]
                   all       3321      15491      0.668      0.533      0.598      0.408

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
        3/5      1.61G      1.147      1.139      1.183         36        640: 100%|██████████| 9768/9768 [36:33<00:00,  4.45it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 208/208 [00:50<00:00,  4.12it/s]
                   all       3321      15491      0.689      0.546      0.613      0.425

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
        4/5      1.51G      1.115      1.086      1.164         14        640: 100%|██████████| 9768/9768 [36:16<00:00,  4.49it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 208/208 [00:49<00:00,  4.23it/s]
                   all       3321      15491      0.697      0.558       0.63      0.437

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
        5/5      1.66G      1.088      1.038      1.147         22        640: 100%|██████████| 9768/9768 [37:34<00:00,  4.33it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 208/208 [00:53<00:00,  3.88it/s]
                   all       3321      15491      0.706      0.573       0.64      0.447

5 epochs completed in 3.234 hours.
Optimizer stripped from runs\detect\train9\weights\last.pt, 6.2MB
Optimizer stripped from runs\detect\train9\weights\best.pt, 6.2MB

Validating runs\detect\train9\weights\best.pt...
Ultralytics 8.3.31  Python-3.10.14 torch-2.4.1+cu118 CUDA:0 (NVIDIA GeForce GTX 1060 6GB, 6144MiB)
Model summary (fused): 168 layers, 3,007,793 parameters, 0 gradients, 8.1 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 208/208 [01:01<00:00,  3.39it/s]
                   all       3321      15491      0.705      0.573       0.64      0.447
                person       2693      10777      0.795      0.641      0.741       0.51
               bicycle        149        314      0.687      0.347      0.419      0.235
                   car        535       1918      0.697      0.493      0.563      0.351
            motorcycle        159        367      0.731      0.549      0.633      0.379
                   bus        189        283      0.784      0.642      0.728      0.588
                 truck        250        414      0.584       0.36      0.414      0.264
                   cat        184        202      0.727      0.807      0.837      0.632
                   dog        177        218      0.704      0.643      0.703      0.549
                 horse        128        272      0.782       0.61      0.721      0.528
                 sheep         65        354      0.567      0.625      0.625      0.422
                   cow         87        372      0.702       0.59       0.66      0.457
Speed: 0.3ms preprocess, 3.7ms inference, 0.0ms loss, 2.4ms postprocess per image
Results saved to runs\detect\train9
Pruned and fine-tuned model saved as 'yolov8n_structured_pruned_finetuned.pt'.
Ultralytics 8.3.31  Python-3.10.14 torch-2.4.1+cu118 CUDA:0 (NVIDIA GeForce GTX 1060 6GB, 6144MiB)
Model summary (fused): 168 layers, 3,007,793 parameters, 0 gradients, 8.1 GFLOPs
val: Scanning J:\quantization\datasets\custom_data\labels\val.cache... 3321 images, 0 backgrounds, 0 corrupt: 100%|██████████| 3321/3321 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 208/208 [01:04<00:00,  3.25it/s]
                   all       3321      15491      0.706      0.573       0.64      0.447
                person       2693      10777      0.796      0.641      0.741       0.51
               bicycle        149        314      0.688      0.347      0.418      0.235
                   car        535       1918      0.698      0.492      0.562      0.351
            motorcycle        159        367       0.73      0.552      0.632       0.38
                   bus        189        283      0.785       0.64      0.728      0.588
                 truck        250        414      0.585       0.36      0.414      0.264
                   cat        184        202      0.731      0.806      0.837      0.631
                   dog        177        218      0.703      0.641      0.703       0.55
                 horse        128        272      0.783       0.61      0.721      0.528
                 sheep         65        354       0.57      0.624      0.624      0.424
                   cow         87        372      0.699      0.589      0.661      0.455
Speed: 0.4ms preprocess, 4.2ms inference, 0.0ms loss, 2.7ms postprocess per image
Results saved to runs\detect\val24
Fine-tuned mAP50-95: 0.44696036711098114
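The two parameter counts come from two different summaries in the output above. Validation prints `Model summary (fused): 168 layers, 3,007,793 parameters`, while training prints the unfused `Model summary: 225 layers, 3,012,993 parameters`. The gap is consistent with Conv-BatchNorm fusion, which folds each BatchNorm into the preceding bias-free conv. The BatchNorm's two learnable vectors (one weight and one bias per channel) are replaced by a single conv bias. The sketch below uses PyTorch's `torch.nn.utils.fusion.fuse_conv_bn_eval` on one conv shaped like layer 1 of the training summary (`Conv [16, 32, 3, 2]`, 4,672 parameters). It is not the Ultralytics fusion code:

```python
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)
# Shaped like layer 1 in the summary above: bias-free conv + BatchNorm
conv = nn.Conv2d(16, 32, 3, stride=2, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(32).eval()

# Give BatchNorm non-trivial values so the equivalence check is meaningful
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.1, 0.1)
    bn.running_mean.uniform_(-0.1, 0.1)
    bn.running_var.uniform_(0.5, 1.5)

# running_mean / running_var are buffers, not parameters, so they are not counted
unfused = sum(p.numel() for p in conv.parameters()) + sum(p.numel() for p in bn.parameters())
fused_conv = fuse_conv_bn_eval(conv, bn)  # BatchNorm folded into the conv weight plus a new bias
fused = sum(p.numel() for p in fused_conv.parameters())

x = torch.randn(1, 16, 16, 16)
with torch.no_grad():
    same = torch.allclose(bn(conv(x)), fused_conv(x), atol=1e-5)
print(unfused, "->", fused, "| same output:", same)
```

Each Conv+BN pair loses one parameter per output channel. Summing output channels over all Conv+BN pairs in the layer table above gives 5,200, which is exactly 3,012,993 - 3,007,793. That sum covers the backbone, the neck, and the two 3x3 convs per branch in `Detect`. The layer count also drops by 57 (225 to 168), one for each removed BatchNorm.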
