diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 82d8d0a84f0..e35981854e5 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -26,6 +26,8 @@ Operators drop_block3d generalized_box_iou generalized_box_iou_loss + distance_box_iou + distance_box_iou_loss masks_to_boxes nms ps_roi_align diff --git a/test/test_ops.py b/test/test_ops.py index c546710b271..96cfb630e8d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1258,6 +1258,97 @@ def test_giou_jit(self) -> None: self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) +class TestDistanceBoxIoU(BoxTestBase): + def _target_fn(self): + return (True, ops.distance_box_iou) + + def _generate_int_input(): + return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] + + def _generate_int_expected(): + return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + + def _generate_float_input(): + return [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] + + def _generate_float_expected(): + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), + ], + ) + def test_distance_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(test_input, dtypes, tolerance, expected) + + def test_distance_iou_jit(self): + self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) +def test_distance_iou_loss(dtype, device): + box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) + box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) + box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) + box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) + + box1s = torch.stack( + [box2, box2], + dim=0, + ) + box2s = torch.stack( + [box3, box4], + dim=0, + ) + + def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"): + output = ops.distance_box_iou_loss(box1, box2, reduction=reduction) + # TODO: When passing the dtype, the torch.half fails as usual. + expected_output = torch.tensor(expected_output, device=device) + tol = 1e-5 if dtype != torch.half else 1e-3 + torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) + + assert_distance_iou_loss(box1, box1, 0.0) + + assert_distance_iou_loss(box1, box2, 0.8125) + + assert_distance_iou_loss(box1, box3, 1.1923) + + assert_distance_iou_loss(box1, box4, 1.2500) + + assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean") + assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum") + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) +def test_empty_distance_iou_inputs(dtype, device) -> None: + box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + + loss = ops.distance_box_iou_loss(box1, box2, reduction="mean") + loss.backward() + + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol) + assert box1.grad is not None, "box1.grad should not be None after backward is called" + assert box2.grad is not None, "box2.grad should not be None after backward is called" + + loss = ops.distance_box_iou_loss(box1, box2, reduction="none") + assert loss.numel() == 0, "diou_loss for two empty box should be empty" + + class TestCompleteBoxIou(BoxTestBase): def _target_fn(self) -> Tuple[bool, Callable]: return (True, ops.complete_box_iou) @@ -1676,6 +1767,7 @@ def test_ciou_loss(self, dtype, device): def assert_ciou_loss(box1, box2, expected_output, reduction="none"): output = ops.complete_box_iou_loss(box1, box2, reduction=reduction) + # TODO: When passing the dtype, the torch.half test doesn't pass... expected_output = torch.tensor(expected_output, device=device) tol = 1e-5 if dtype != torch.half else 1e-3 torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 9d99db7125c..cd711578a6c 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -7,12 +7,14 @@ box_area, box_iou, generalized_box_iou, + distance_box_iou, complete_box_iou, masks_to_boxes, ) from .boxes import box_convert from .ciou_loss import complete_box_iou_loss from .deform_conv import deform_conv2d, DeformConv2d +from .diou_loss import distance_box_iou_loss from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss @@ -40,6 +42,8 @@ "box_area", "box_iou", "generalized_box_iou", + "distance_box_iou", + "complete_box_iou", "roi_align", "RoIAlign", "roi_pool", @@ -58,6 +62,8 @@ "Conv3dNormActivation", "SqueezeExcitation", "generalized_box_iou_loss", + "distance_box_iou_loss", + "complete_box_iou_loss", "drop_block2d", "DropBlock2d", "drop_block3d", diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 3239ba0d60a..3b994879ecf 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -359,6 +359,50 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v +def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor: + """ + Return distance intersection-over-union (Jaccard index) between two sets of boxes. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + eps (float, optional): small number to prevent division by zero. Default: 1e-7 + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values + for every element in boxes1 and boxes2 + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(distance_box_iou) + + boxes1 = _upcast(boxes1) + boxes2 = _upcast(boxes2) + + inter, union = _box_inter_union(boxes1, boxes2) + iou = inter / union + + lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2] + diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps + + # centers of boxes + x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2 + y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2 + x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2 + y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2 + # The distance between boxes' centers squared. + centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2) + + # The distance IoU is the IoU penalized by a normalized + # distance between boxes' centers squared. + return iou - (centers_distance_squared / diagonal_distance_squared) + + def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: """ Compute the bounding boxes around the provided masks. diff --git a/torchvision/ops/diou_loss.py b/torchvision/ops/diou_loss.py new file mode 100644 index 00000000000..ea7ead19344 --- /dev/null +++ b/torchvision/ops/diou_loss.py @@ -0,0 +1,86 @@ +import torch + +from ..utils import _log_api_usage_once +from .boxes import _upcast + + +def distance_box_iou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + reduction: str = "none", + eps: float = 1e-7, +) -> torch.Tensor: + """ + Gradient-friendly IoU loss with an additional penalty that is non-zero when the + distance between boxes' centers isn't zero. Indeed, for two exactly overlapping + boxes, the distance IoU is the same as the IoU loss. + This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the + same dimensions. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[N, 4]): second set of boxes + reduction (string, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be + applied to the output. ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. Default: ``'none'`` + eps (float, optional): small number to prevent division by zero. Default: 1e-7 + + Returns: + Tensor: Loss tensor with the reduction option applied. + + Reference: + Zhaohui Zheng et. al: Distance Intersection over Union Loss: + https://arxiv.org/abs/1911.08287 + """ + + # Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(distance_box_iou_loss) + + boxes1 = _upcast(boxes1) + boxes2 = _upcast(boxes2) + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + + # Intersection keypoints + xkis1 = torch.max(x1, x1g) + ykis1 = torch.max(y1, y1g) + xkis2 = torch.min(x2, x2g) + ykis2 = torch.min(y2, y2g) + + intsct = torch.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps + iou = intsct / union + + # smallest enclosing box + xc1 = torch.min(x1, x1g) + yc1 = torch.min(y1, y1g) + xc2 = torch.max(x2, x2g) + yc2 = torch.max(y2, y2g) + # The diagonal distance of the smallest enclosing box squared + diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps + + # centers of boxes + x_p = (x2 + x1) / 2 + y_p = (y2 + y1) / 2 + x_g = (x1g + x2g) / 2 + y_g = (y1g + y2g) / 2 + # The distance between boxes' centers squared. + centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) + + # The distance IoU is the IoU penalized by a normalized + # distance between boxes' centers squared. + loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared) + if reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + return loss diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index c43a788063e..4d6f946f5e8 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -36,7 +36,7 @@ def generalized_box_iou_loss( ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be applied to the output. ``'mean'``: The output will be averaged. ``'sum'``: The output will be summed. Default: ``'none'`` - eps (float, optional): small number to prevent division by zero. Default: 1e-7 + eps (float): small number to prevent division by zero. Default: 1e-7 Reference: Hamid Rezatofighi et. al: Generalized Intersection over Union: