Skip to content

Commit 5a9ee49

Browse files
authored
Fix d/c IoU for different batch sizes (#6343)
* Fix d/c IoU for different batch sizes (#6338) * Fix tests * Fix linter
1 parent 053feed commit 5a9ee49

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

test/test_ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,7 +1266,11 @@ def _generate_int_input():
12661266
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
12671267

12681268
def _generate_int_expected():
1269-
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
1269+
return [
1270+
[1.0000, 0.1875, -0.4444],
1271+
[0.1875, 1.0000, -0.5625],
1272+
[-0.4444, -0.5625, 1.0000],
1273+
]
12701274

12711275
def _generate_float_input():
12721276
return [
@@ -1357,7 +1361,11 @@ def _generate_int_input() -> List[List[int]]:
13571361
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
13581362

13591363
def _generate_int_expected() -> List[List[float]]:
1360-
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
1364+
return [
1365+
[1.0000, 0.1875, -0.4444],
1366+
[0.1875, 1.0000, -0.5625],
1367+
[-0.4444, -0.5625, 1.0000],
1368+
]
13611369

13621370
def _generate_float_input() -> List[List[float]]:
13631371
return [

torchvision/ops/boxes.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,13 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
325325

326326
diou, iou = _box_diou_iou(boxes1, boxes2, eps)
327327

328-
w_pred = boxes1[:, 2] - boxes1[:, 0]
329-
h_pred = boxes1[:, 3] - boxes1[:, 1]
328+
w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
329+
h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]
330330

331331
w_gt = boxes2[:, 2] - boxes2[:, 0]
332332
h_gt = boxes2[:, 3] - boxes2[:, 1]
333333

334-
v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
334+
v = (4 / (torch.pi ** 2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
335335
with torch.no_grad():
336336
alpha = v / (1 - iou + v + eps)
337337
return diou - alpha * v
@@ -358,7 +358,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
358358

359359
boxes1 = _upcast(boxes1)
360360
boxes2 = _upcast(boxes2)
361-
diou, _ = _box_diou_iou(boxes1, boxes2)
361+
diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps)
362362
return diou
363363

364364

@@ -375,7 +375,9 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te
375375
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
376376
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
377377
# The distance between boxes' centers squared.
378-
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)
378+
centers_distance_squared = (_upcast((x_p[:, None] - x_g[None, :])) ** 2) + (
379+
_upcast((y_p[:, None] - y_g[None, :])) ** 2
380+
)
379381
# The distance IoU is the IoU penalized by a normalized
380382
# distance between boxes' centers squared.
381383
return iou - (centers_distance_squared / diagonal_distance_squared), iou

torchvision/ops/ciou_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def complete_box_iou_loss(
1414

1515
"""
1616
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
17-
boxes do not overlap overlap area, This loss function considers important geometrical
18-
factors such as overlap area, normalized central point distance and aspect ratio.
17+
boxes do not overlap. This loss function considers important geometrical
18+
factors such as overlap area, normalized central point distance and aspect ratio.
1919
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
2020
2121
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
@@ -35,7 +35,7 @@ def complete_box_iou_loss(
3535
Tensor: Loss tensor with the reduction option applied.
3636
3737
Reference:
38-
Zhaohui Zheng et. al: Complete Intersection over Union Loss:
38+
Zhaohui Zheng et al.: Complete Intersection over Union Loss:
3939
https://arxiv.org/abs/1911.08287
4040
4141
"""

0 commit comments

Comments
 (0)