study

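When a YOLO model is used for object detection on images, the raw output can contain several overlapping boxes for the same object. These are filtered in a post-processing step called non-maximum suppression (NMS), which keeps only the best predicted box for each object and removes overlapping duplicates with lower confidence scores.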

Algorithm
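
As a rough illustration of the greedy procedure described above, here is a minimal pure-Python sketch. The names iou and greedy_nms are illustrative only and do not come from any library. It assumes boxes are given as (x1, y1, x2, y2) and that a box is suppressed when its IoU with an already kept box is strictly greater than the threshold.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    # Width and height of the intersection rectangle (zero if disjoint).
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_threshold):
    """Return the indices of the kept boxes, highest score first."""
    # Visit candidates from highest to lowest confidence.
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep box i only if no already kept box overlaps it too much.
        if all(iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep


boxes = [(10, 10, 110, 110), (12, 12, 112, 112), (200, 200, 300, 300)]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores, iou_threshold=0.5))  # [0, 2]
```

The second box overlaps the first with an IoU of about 0.92, so it is dropped as a duplicate, while the third box is far away and survives.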


Implementation

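In the Ultralytics YOLO library, NMS is already fully wrapped. Users do not need to write the algorithm themselves; they control it through the conf and iou parameters. conf is the minimum confidence a detection needs in order to be kept, and iou is the IoU threshold above which a lower-scoring overlapping box is suppressed.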

from ultralytics import YOLO
 
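# Load a pretrained YOLOv10n model.
# Note: YOLOv10 is designed for NMS-free inference, so `iou` is mostly
# relevant to NMS-based models such as YOLOv8 ("yolov8n.pt").
model = YOLO("yolov10n.pt")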
 
# Perform object detection on an image
results = model.predict("image.jpg", conf=0.25, iou=0.7)
 
# Display the results
results[0].show()

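Under the hood it relies on the torchvision.ops.nms function, which is class-agnostic. To make NMS filter per class, YOLO offsets the box coordinates so that boxes of different classes land in non-overlapping regions, then passes everything to NMS in a single call.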

import torch
import torchvision
 
def batched_nms(boxes, scores, class_ids, iou_threshold):
    """
    Implementation of Class-Specific NMS using the 'Offset Trick'.
    This allows processing all classes in a single CUDA kernel execution.
 
    Args:
        boxes (Tensor[N, 4]): Bounding boxes (x1, y1, x2, y2).
        scores (Tensor[N]): Confidence scores for each box.
        class_ids (Tensor[N]): Class indices (e.g., 0, 1, 2...).
        iou_threshold (float): IoU threshold for suppression.
 
    Returns:
        keep (Tensor): Indices of the boxes to keep.
    """
 
    # 0. Nothing to suppress: boxes.max() would fail on an empty tensor.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # 1. Calculate offsets
    # Find the maximum coordinate so that shifted classes cannot overlap.
    # Each class gets a unique shift: class_id * (max_coordinate + 1)
    max_coordinate = boxes.max()
    offsets = class_ids.to(boxes.dtype) * (max_coordinate + 1)

    # 2. Apply offsets to create "virtual coordinates"
    # Example with max_coordinate = 1000 (offset step 1001):
    #   class 0 boxes stay within [0, 1000]
    #   class 1 boxes shift to [1001, 2001]
    # so the IoU between a class 0 box and a class 1 box becomes 0.
    boxes_for_nms = boxes + offsets[:, None]
 
    # 3. Perform NMS
    # Use torchvision's highly optimized CUDA implementation.
    keep_indices = torchvision.ops.nms(boxes_for_nms, scores, iou_threshold)
 
    return keep_indices
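
To see why the offset separates classes, the following small check may help. It is a sketch in plain Python (no PyTorch needed); the iou helper and the sample boxes are made up for illustration, and it applies the same class_id * (max_coordinate + 1) shift as the function above.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


# Two heavily overlapping boxes that belong to different classes.
boxes = [(100.0, 100.0, 200.0, 200.0), (110.0, 110.0, 210.0, 210.0)]
class_ids = [0, 1]

# Shift every coordinate of a box by class_id * (max_coordinate + 1).
max_coordinate = max(max(b) for b in boxes)  # 210.0
shifted = [
    tuple(v + c * (max_coordinate + 1) for v in b)
    for b, c in zip(boxes, class_ids)
]

print(iou(boxes[0], boxes[1]))      # about 0.68 before the shift
print(iou(shifted[0], shifted[1]))  # 0.0 after the shift
```

Because the class 1 box moves to (321, 321, 421, 421), it can no longer intersect anything from class 0, so a class-agnostic NMS call never suppresses across classes. The indices it returns still refer to the original, unshifted boxes.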

Application

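In Kibo-RPC 2025, objects of different classes are known not to overlap. The pipeline can therefore be simplified: skip the per-class grouping step and run a single NMS pass over all detections. The following version does not check classId and filters only by IoU and confidence.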

/**
 * Applies Non-Maximum Suppression (NMS) to remove overlapping detections.
 *
 * @param detections List of Detection objects to be filtered.
 * @param iouThreshold IoU threshold for filtering.
 * @return Detections that remain after NMS, in descending order of confidence.
 */
private List<Detection> nms(List<Detection> detections, float iouThreshold) {
    List<Detection> results = new ArrayList<>();
 
    if (detections.isEmpty()) {
        Log.i(TAG, "No detections to process.");
        return results;
    }
 
    // Sort a copy by confidence in descending order so the caller's list is left untouched
    List<Detection> remaining = new ArrayList<>(detections);
    remaining.sort(new Comparator<Detection>() {
        @Override
        public int compare(Detection det1, Detection det2) {
            return Float.compare(det2.confidence, det1.confidence);
        }
    });

    // Perform NMS
    while (!remaining.isEmpty()) {
        // Pick the detection with highest confidence
        Detection first = remaining.remove(0);
        results.add(first);

        // Remove every remaining detection that overlaps it too much
        Iterator<Detection> iterator = remaining.iterator();
        while (iterator.hasNext()) {
            Detection next = iterator.next();
            if (calculateIoU(first.box, next.box) >= iouThreshold) {
                iterator.remove();
            }
        }
    }
 
    return results;
}
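
The simplification depends on the assumption that classes never overlap. The Python sketch below is one way to check that assumption on sample data. All names and boxes are hypothetical, detections are modelled as (box, score, class_id) tuples, and the suppression rule is IoU >= threshold, as in the Java method above.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(dets, thr):
    """Class-agnostic greedy NMS over (box, score, class_id) tuples."""
    kept = []
    for det in sorted(dets, key=lambda d: d[1], reverse=True):
        if all(iou(det[0], k[0]) < thr for k in kept):
            kept.append(det)
    return kept


def nms_per_class(dets, thr):
    """Run NMS separately inside each class, then merge by score."""
    kept = []
    for c in {d[2] for d in dets}:
        kept += nms([d for d in dets if d[2] == c], thr)
    return sorted(kept, key=lambda d: d[1], reverse=True)


# Classes occupy separate regions: both variants agree.
separate = [((0, 0, 100, 100), 0.9, 0), ((5, 5, 105, 105), 0.6, 0),
            ((300, 300, 400, 400), 0.8, 1)]
# Classes overlap: the class-agnostic variant drops the class 1 box.
overlapping = [((0, 0, 100, 100), 0.9, 0), ((5, 5, 105, 105), 0.8, 1)]

print(nms(separate, 0.5) == nms_per_class(separate, 0.5))                # True
print(len(nms(overlapping, 0.5)), len(nms_per_class(overlapping, 0.5)))  # 1 2
```

When objects of different classes can overlap, the class-agnostic version may discard a valid detection, so the shortcut is only safe under the stated assumption.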