Skip to content

Commit a0d0d45

Browse files
authored
Merge pull request #500 from mario-dg/add_nmm_to_detections
Add Non-Maximum Merging (NMM) to Detections
2 parents 4a365fb + 2ee9e08 commit a0d0d45

File tree

5 files changed

+606
-26
lines changed

5 files changed

+606
-26
lines changed

supervision/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from supervision.detection.tools.smoother import DetectionsSmoother
4848
from supervision.detection.utils import (
4949
box_iou_batch,
50+
box_non_max_merge,
5051
box_non_max_suppression,
5152
calculate_masks_centroids,
5253
clip_boxes,

supervision/detection/core.py

+192
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
1010
from supervision.detection.lmm import LMM, from_paligemma, validate_lmm_and_kwargs
1111
from supervision.detection.utils import (
12+
box_iou_batch,
13+
box_non_max_merge,
1214
box_non_max_suppression,
1315
calculate_masks_centroids,
1416
extract_ultralytics_masks,
@@ -1197,3 +1199,193 @@ def with_nms(
11971199
)
11981200

11991201
return self[indices]
1202+
1203+
def with_nmm(
    self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
    """
    Apply non-maximum merging (NMM) to the detections held by this object.

    The detections are split into merge groups by `box_non_max_merge`; every
    group is then collapsed into a single detection by
    `merge_inner_detections_objects`, and the per-group results are combined
    with `Detections.merge`.

    Args:
        threshold (float, optional): Intersection-over-union threshold passed
            both to the grouping step and to the per-group merge.
            Defaults to 0.5.
        class_agnostic (bool, optional): When True, `class_id` is left out of
            the predictions handed to the grouping step, so it is ignored.
            Defaults to False.

    Returns:
        Detections: The detections after non-maximum merging. An empty
            Detections object is returned unchanged.

    Raises:
        AssertionError: If `confidence` is None, or if `class_id` is None
            while `class_agnostic` is False.
    """
    if len(self) == 0:
        return self

    assert (
        self.confidence is not None
    ), "Detections confidence must be given for NMM to be executed."

    if class_agnostic:
        class_column = ()
    else:
        assert self.class_id is not None, (
            "Detections class_id must be given for NMM to be executed. If you"
            " intended to perform class agnostic NMM set class_agnostic=True."
        )
        class_column = (self.class_id.reshape(-1, 1),)

    # One row per detection: box coordinates, confidence and, unless the
    # merge is class agnostic, the class id as a trailing column.
    predictions = np.hstack(
        (self.xyxy, self.confidence.reshape(-1, 1), *class_column)
    )

    merge_groups = box_non_max_merge(
        predictions=predictions, iou_threshold=threshold
    )

    # Each group is a collection of indices into `self`; indexing with a single
    # index yields the one-object Detections the merge helper works on.
    merged_per_group = [
        merge_inner_detections_objects([self[index] for index in group], threshold)
        for group in merge_groups
    ]

    return Detections.merge(merged_per_group)
1259+
1260+
1261+
def merge_inner_detection_object_pair(
    detections_1: Detections, detections_2: Detections
) -> Detections:
    """
    Merges two Detections objects into a single Detections object.
    Assumes each Detections contains exactly one object.

    A `winning` detection is determined based on the confidence score of the two
    input detections (the first one wins ties, and also wins when neither has a
    confidence). This winning detection is then used to specify which
    `class_id`, `tracker_id`, and `data` to include in the merged Detections object.

    The resulting `confidence` of the merged object is the area-weighted average
    of the two input confidences. If both boxes have zero total area the weights
    are undefined, so the plain average of the two confidences is used instead.
    The bounding boxes and masks of the two input detections are merged into a
    single enclosing bounding box and a single mask (logical OR), respectively.

    Args:
        detections_1 (Detections):
            The first Detections object
        detections_2 (Detections):
            The second Detections object

    Returns:
        Detections: A new Detections object, with merged attributes.

    Raises:
        ValueError: If the input Detections objects do not have exactly 1 detected
            object, or if an optional field (`mask`, `confidence`, `class_id`,
            `tracker_id`) is set on only one of the two inputs.

    Example:
        ```python
        import cv2
        import supervision as sv
        from inference import get_model

        image = cv2.imread(<SOURCE_IMAGE_PATH>)
        model = get_model(model_id="yolov8s-640")

        result = model.infer(image)[0]
        detections = sv.Detections.from_inference(result)

        merged_detections = merge_inner_detection_object_pair(
            detections[0], detections[1])
        ```
    """
    if len(detections_1) != 1 or len(detections_2) != 1:
        raise ValueError("Both Detections should have exactly 1 detected object.")

    # Guarantees every optional field is set on both inputs or on neither, so
    # the `None` checks below never mix a value with a missing one.
    validate_fields_both_defined_or_none(detections_1, detections_2)

    xyxy_1 = detections_1.xyxy[0]
    xyxy_2 = detections_2.xyxy[0]

    if detections_1.confidence is None and detections_2.confidence is None:
        merged_confidence = None
        winning_detection = detections_1
    else:
        confidence_1 = detections_1.confidence[0]
        confidence_2 = detections_2.confidence[0]
        area_1 = (xyxy_1[2] - xyxy_1[0]) * (xyxy_1[3] - xyxy_1[1])
        area_2 = (xyxy_2[2] - xyxy_2[0]) * (xyxy_2[3] - xyxy_2[1])
        total_area = area_1 + area_2
        if total_area == 0:
            # Degenerate boxes: an area-weighted mean would divide by zero and
            # produce NaN, so weight both detections equally.
            merged_value = (confidence_1 + confidence_2) / 2
        else:
            merged_value = (area_1 * confidence_1 + area_2 * confidence_2) / total_area
        merged_confidence = np.array([merged_value])
        winning_detection = (
            detections_1 if confidence_1 >= confidence_2 else detections_2
        )

    # Smallest box enclosing both inputs.
    merged_x1, merged_y1 = np.minimum(xyxy_1[:2], xyxy_2[:2])
    merged_x2, merged_y2 = np.maximum(xyxy_1[2:], xyxy_2[2:])
    merged_xyxy = np.array([[merged_x1, merged_y1, merged_x2, merged_y2]])

    if detections_1.mask is None and detections_2.mask is None:
        merged_mask = None
    else:
        merged_mask = np.logical_or(detections_1.mask, detections_2.mask)

    return Detections(
        xyxy=merged_xyxy,
        mask=merged_mask,
        confidence=merged_confidence,
        class_id=winning_detection.class_id,
        tracker_id=winning_detection.tracker_id,
        data=winning_detection.data,
    )
1348+
1349+
1350+
def merge_inner_detections_objects(
    detections: List[Detections], threshold=0.5
) -> Detections:
    """
    Fold a list of single-object detections into one single-object detection.

    The first entry seeds the running result. Each following entry is merged
    into it with `merge_inner_detection_object_pair` (boxes, masks, etc. are
    combined there) as long as its box IoU with the running result is at least
    `threshold`. The first entry that falls below the threshold stops the
    process, and it and all later entries are left out of the result.

    Args:
        detections (List[Detections]): Detections to merge, in merge order.
        threshold: Minimum IoU between the running result and the next entry
            for the merge to continue. Defaults to 0.5.

    Returns:
        Detections: The merged detection.
    """
    merged = detections[0]
    for candidate in detections[1:]:
        # IoU is measured against the box merged so far, not the seed box.
        iou_with_merged = box_iou_batch(merged.xyxy, candidate.xyxy)[0]
        if iou_with_merged < threshold:
            return merged
        merged = merge_inner_detection_object_pair(merged, candidate)
    return merged
1368+
1369+
1370+
def validate_fields_both_defined_or_none(
    detections_1: Detections, detections_2: Detections
) -> None:
    """
    Check that two Detections agree on which optional fields are populated.

    For each of `mask`, `confidence`, `class_id` and `tracker_id`, the field must
    be None in both objects or non-None in both objects. The `data` field is
    not checked.

    Raises:
        ValueError: If any checked field is None in one object and set in the
            other. The first offending field, in the order above, is reported.
    """
    for attribute in ("mask", "confidence", "class_id", "tracker_id"):
        missing_in_first = getattr(detections_1, attribute) is None
        missing_in_second = getattr(detections_2, attribute) is None

        if missing_in_first is not missing_in_second:
            raise ValueError(
                f"Field '{attribute}' should be consistently None or not None in both "
                "Detections."
            )

0 commit comments

Comments
 (0)