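When a YOLO model is used for object detection in images, its raw output can contain many overlapping boxes for the same object. These are filtered during post-processing: Non-Maximum Suppression (NMS) keeps only the best predicted box for each object and removes the overlapping duplicates with lower confidence scores.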
Algorithm
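As a rough illustration of the greedy procedure (sort by confidence, keep the top box, drop everything that overlaps it too much, repeat), here is a minimal pure-Python sketch. The names iou and greedy_nms, the sample boxes, and the 0.5 threshold are illustrative assumptions and do not come from any library; real pipelines would normally call an optimized routine instead.

```python
def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_threshold):
    """Return the indices of the kept boxes, highest score first."""
    # Indices sorted by descending confidence
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)  # highest-scoring box still in play
        keep.append(best)
        # Drop every remaining box that overlaps the kept one too much
        order = [i for i in order if iou(boxes[best], boxes[i]) <= iou_threshold]
    return keep


# Two near-duplicate boxes around one object, plus one separate object
boxes = [(10, 10, 110, 110), (12, 12, 112, 112), (200, 200, 300, 300)]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores, iou_threshold=0.5))  # [0, 2]
```

The second box overlaps the first with an IoU of about 0.92, so it is suppressed; the third box does not overlap anything and survives.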
Implementation
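In the Ultralytics YOLO library, NMS is fully encapsulated. Users do not need to write the algorithm themselves; they control it through two parameters: conf (the minimum confidence score a detection needs in order to be kept) and iou (the IoU threshold above which overlapping boxes are suppressed).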
from ultralytics import YOLO
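# Load a pretrained YOLOv10n model.
# Note: YOLOv10 is designed to run without NMS, so the iou setting matters mainly for NMS-based models such as YOLOv8.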
model = YOLO("yolov10n.pt")
# Perform object detection on an image
results = model.predict("image.jpg", conf=0.25, iou=0.7)
# Display the results
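results[0].show()

Under the hood, it relies on torchvision's torchvision.ops.nms function, which is class-agnostic. To make NMS filter each class separately, YOLO offsets the box coordinates so that boxes of different classes are moved into non-overlapping regions, then sends all of them to NMS in a single call.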
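To see why the offset works, the idea can first be checked on plain numbers. The sketch below is an illustration only: iou and shift are hypothetical helper names, and the two boxes and their class labels are made up. It shows that shifting a box by class_id * stride leaves the IoU between same-class boxes unchanged while forcing the IoU between different classes to zero.

```python
def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def shift(box, class_id, stride):
    """Translate a box by class_id * stride along both axes."""
    off = class_id * stride
    return (box[0] + off, box[1] + off, box[2] + off, box[3] + off)


box_a = (100, 100, 300, 300)
box_b = (110, 110, 310, 310)

# The stride must exceed the largest coordinate so shifted classes cannot touch
stride = max(box_a + box_b) + 1  # 311

raw_iou = iou(box_a, box_b)
# Different classes (0 and 1): the shifted boxes no longer intersect
cross_class_iou = iou(shift(box_a, 0, stride), shift(box_b, 1, stride))
# Same class (both 1): both boxes move together, so the IoU is unchanged
same_class_iou = iou(shift(box_a, 1, stride), shift(box_b, 1, stride))

print(round(raw_iou, 3), cross_class_iou, round(same_class_iou, 3))
```

The tensor version that follows applies the same idea to all boxes at once with torchvision.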
import torch
import torchvision
def batched_nms(boxes, scores, class_ids, iou_threshold):
    """
    Class-specific NMS using the "offset trick".

    All classes are processed by a single call to torchvision.ops.nms.

    Args:
        boxes (Tensor[N, 4]): Bounding boxes (x1, y1, x2, y2).
        scores (Tensor[N]): Confidence scores for each box.
        class_ids (Tensor[N]): Class indices (e.g., 0, 1, 2...).
        iou_threshold (float): IoU threshold for suppression.

    Returns:
        keep (Tensor): Indices of the boxes to keep.
    """
    # Nothing to suppress; boxes.max() would also fail on an empty tensor.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # 1. Calculate offsets
    # Use the maximum coordinate so that shifted classes cannot overlap.
    # Each class gets a unique shift: class_id * (max_coord + 1)
    max_coordinate = boxes.max()
    offsets = class_ids.to(boxes.dtype) * (max_coordinate + 1)

    # 2. Apply offsets to create "virtual coordinates"
    # Class 0 boxes stay where they are, e.g. within [0, 1000].
    # Class 1 boxes shift by max_coord + 1, e.g. to [4001, 5001] if max_coord is 4000.
    # The IoU between a class 0 box and a class 1 box becomes 0.
    boxes_for_nms = boxes + offsets[:, None]

    # 3. Perform NMS
    # torchvision's optimized implementation runs on CUDA when the tensors are on a GPU.
    keep_indices = torchvision.ops.nms(boxes_for_nms, scores, iou_threshold)
    return keep_indices

Application
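In Kibo-RPC 2025, objects of different classes are known not to overlap. The pipeline can therefore be simplified: skip the per-class grouping and run a single NMS pass over all detections. The version below does not check classId and filters solely on IoU and confidence.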
// Uses java.util.ArrayList, Comparator, Iterator, and List.
/**
 * Applies Non-Maximum Suppression (NMS) to remove overlapping detections.
 * Class-agnostic: classId is not checked. The input list is left unmodified.
 *
 * @param detections   List of Detection objects to be filtered.
 * @param iouThreshold Detections whose IoU with a kept detection is at or above this value are removed.
 * @return Detections that remain after NMS, ordered by descending confidence.
 */
private List<Detection> nms(List<Detection> detections, float iouThreshold) {
    List<Detection> results = new ArrayList<>();
    if (detections.isEmpty()) {
        Log.i(TAG, "No detections to process.");
        return results;
    }

    // Work on a copy so the caller's list is not reordered or emptied
    List<Detection> remaining = new ArrayList<>(detections);

    // Sort detections by confidence in descending order
    remaining.sort(new Comparator<Detection>() {
        @Override
        public int compare(Detection det1, Detection det2) {
            return Float.compare(det2.confidence, det1.confidence);
        }
    });

    // Perform NMS
    while (!remaining.isEmpty()) {
        // Pick the detection with the highest confidence
        Detection first = remaining.remove(0);
        results.add(first);

        // Remove the detections that overlap it too much
        Iterator<Detection> iterator = remaining.iterator();
        while (iterator.hasNext()) {
            Detection next = iterator.next();
            if (calculateIoU(first.box, next.box) >= iouThreshold) {
                iterator.remove();
            }
        }
    }
    return results;
}
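The simplification above depends on the assumption that different classes never overlap. The following Python sketch, which loosely mirrors the Java logic, shows what happens when that assumption is violated. The Detection dataclass, the class_aware flag, and the sample detections are hypothetical and are not part of the Kibo-RPC code.

```python
from dataclasses import dataclass


@dataclass
class Detection:
    box: tuple          # (x1, y1, x2, y2)
    confidence: float
    class_id: int


def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(detections, iou_threshold, class_aware=False):
    """Greedy NMS; with class_aware=True, only same-class boxes suppress each other."""
    # sorted() returns a new list, so the input is left untouched
    remaining = sorted(detections, key=lambda d: d.confidence, reverse=True)
    results = []
    while remaining:
        first = remaining.pop(0)
        results.append(first)
        remaining = [
            d for d in remaining
            if (class_aware and d.class_id != first.class_id)
            or iou(first.box, d.box) < iou_threshold
        ]
    return results


# Two heavily overlapping detections that belong to different classes
dets = [
    Detection((100, 100, 300, 300), 0.90, class_id=0),
    Detection((110, 110, 310, 310), 0.85, class_id=1),
]
print(len(nms(dets, 0.5)))                    # 1
print(len(nms(dets, 0.5, class_aware=True)))  # 2
```

With the class check disabled, the lower-confidence detection is dropped even though it belongs to a different class, so the single-pass version is only safe while objects of different classes stay apart.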