Skip to content

Add box_area_center and box_iou_center functions for cxcywh format with tests #8992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ These utility functions perform various operations on bounding boxes.
:template: function.rst

box_area
box_area_center
box_convert
box_iou
box_iou_center
clip_boxes_to_image
complete_box_iou
distance_box_iou
Expand Down
102 changes: 102 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,41 @@ def test_box_area_jit(self):
torch.testing.assert_close(scripted_area, expected)


class TestBoxAreaCenter:
    """Tests for ``ops.box_area_center`` on boxes given in (cx, cy, w, h) format.

    Every test builds its input in xyxy form and converts it with
    ``ops.box_convert`` before computing the area.
    """

    def area_check(self, box, expected, atol=1e-4):
        """Assert that the area of ``box`` equals ``expected`` within ``atol``.

        The comparison is absolute only (``rtol=0``) and ignores dtype.
        """
        actual = ops.box_area_center(box)
        torch.testing.assert_close(actual, expected, rtol=0.0, atol=atol, check_dtype=False)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        cxcywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")
        self.area_check(cxcywh, torch.tensor([10000, 0], dtype=torch.int32))

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        xyxy = torch.tensor(FLOAT_BOXES, dtype=dtype)
        cxcywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")
        self.area_check(cxcywh, torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype))

    def test_float16_box(self):
        xyxy = torch.tensor(
            [
                [2.825, 1.8625, 3.90, 4.85],
                [2.825, 4.875, 19.20, 5.10],
                [2.925, 1.80, 8.90, 4.90],
            ],
            dtype=torch.float16,
        )
        cxcywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")

        # Half precision needs a looser absolute tolerance than the default.
        reference = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
        self.area_check(cxcywh, reference, atol=0.01)

    def test_box_area_jit(self):
        xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
        cxcywh = ops.box_convert(xyxy, in_fmt="xyxy", out_fmt="cxcywh")
        # The scripted function must reproduce the eager result.
        eager_area = ops.box_area_center(cxcywh)
        scripted_area = torch.jit.script(ops.box_area_center)(cxcywh)
        torch.testing.assert_close(scripted_area, eager_area)


INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
FLOAT_BOXES = [
Expand All @@ -1459,6 +1494,14 @@ def test_box_area_jit(self):
[279.2440, 197.9812, 1189.4746, 849.2019],
]

# Fixtures in (cx, cy, w, h) format for the centre-format IoU tests.
INT_BOXES_CXCYWH = [
    [50, 50, 100, 100],
    [25, 25, 50, 50],
    [250, 250, 100, 100],
    [10, 10, 20, 20],
]
# The first three boxes of INT_BOXES_CXCYWH, copied so the two fixtures share no rows.
INT_BOXES2_CXCYWH = [list(box) for box in INT_BOXES_CXCYWH[:3]]
FLOAT_BOXES_CXCYWH = [
    [739.4324, 518.5154, 908.1572, 665.8793],
    [738.8228, 519.9021, 907.3512, 662.3295],
    [734.3593, 523.5916, 910.2306, 651.2207],
]


def gen_box(size, dtype=torch.float):
xy1 = torch.rand((size, 2), dtype=dtype)
Expand Down Expand Up @@ -1525,6 +1568,65 @@ def test_iou_cartesian(self):
self._run_cartesian_test(ops.box_iou)


class TestIouCenterBase:
    """Shared helpers for IoU-style tests on boxes in (cx, cy, w, h) format."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn(box1, box2)`` against ``expected`` for every dtype.

        Args:
            target_fn: function under test, called with two box tensors.
            actual_box1: raw data for the first set of boxes.
            actual_box2: raw data for the second set of boxes.
            dtypes: dtypes the raw box data is converted to, one run per dtype.
            atol: absolute tolerance for the comparison (``rtol`` is 0).
            expected: raw data for the expected result.
        """
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Build the tensors from the raw inputs on every iteration.
            # Rebinding the arguments would make later dtypes convert from the
            # previous dtype's tensor, carrying over its rounding.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the scripted ``target_fn`` matches its eager result."""
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Fill an N x M matrix by calling ``target_fn`` on one box pair at a time."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Compare the batched result with the pair-by-pair result on generated boxes."""
        # gen_box output is converted from xyxy to the centre format under test.
        boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh")
        boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh")
        a = TestIouCenterBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)


class TestBoxIouCenter(TestIouCenterBase):
    """Tests for ``ops.box_iou_center`` on boxes in (cx, cy, w, h) format.

    Uses the centre-format helpers from ``TestIouCenterBase``, whose cartesian
    test converts its generated boxes to cxcywh.
    """

    int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.04, 0.16, 0.0]]
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(
                INT_BOXES_CXCYWH,
                INT_BOXES2_CXCYWH,
                [torch.int16, torch.int32, torch.int64],
                1e-4,
                int_expected,
            ),
            # float16 is compared with a looser tolerance than float32/float64.
            pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.box_iou_center, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.box_iou_center, INT_BOXES_CXCYWH)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.box_iou_center)


class TestGeneralizedBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
Expand Down
4 changes: 4 additions & 0 deletions torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from .boxes import (
batched_nms,
box_area,
box_area_center,
box_convert,
box_iou,
box_iou_center,
clip_boxes_to_image,
complete_box_iou,
distance_box_iou,
Expand Down Expand Up @@ -40,7 +42,9 @@
"clip_boxes_to_image",
"box_convert",
"box_area",
"box_area_center",
"box_iou",
"box_iou_center",
"generalized_box_iou",
"distance_box_iou",
"complete_box_iou",
Expand Down
55 changes: 55 additions & 0 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,25 @@ def box_area(boxes: Tensor) -> Tensor:
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


def box_area_center(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (cx, cy, w, h) coordinates.

    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (cx, cy, w, h) format with
            ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.

    Returns:
        Tensor[N]: the area for each box
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # Log this function itself, not ``box_area``, so usage is attributed correctly.
        _log_api_usage_once(box_area_center)
    # Upcast before multiplying so the product is not computed in a narrow dtype.
    boxes = _upcast(boxes)
    # In centre format the area is simply width * height.
    return boxes[:, 2] * boxes[:, 3]


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -329,6 +348,42 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
return iou


def _box_inter_union_center(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Computes pairwise intersection and union areas for boxes in (cx, cy, w, h) format.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tuple[Tensor[N, M], Tensor[N, M]]: the intersection and union areas
        for every pair of boxes
    """
    # Upcast before any corner arithmetic: computing ``c +/- wh / 2`` in a narrow
    # dtype such as float16 loses precision and can overflow before the
    # difference is widened.
    boxes1 = _upcast(boxes1)
    boxes2 = _upcast(boxes2)

    area1 = box_area_center(boxes1)
    area2 = box_area_center(boxes2)

    half_wh1 = boxes1[:, None, 2:] / 2  # [N,1,2]
    half_wh2 = boxes2[:, 2:] / 2  # [M,2]

    # Top-left / bottom-right corners of each pairwise intersection.
    lt = torch.max(boxes1[:, None, :2] - half_wh1, boxes2[:, :2] - half_wh2)  # [N,M,2]
    rb = torch.min(boxes1[:, None, :2] + half_wh1, boxes2[:, :2] + half_wh2)  # [N,M,2]

    # Non-overlapping pairs give a negative extent; clamp so their intersection is 0.
    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    return inter, union


def box_iou_center(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Return intersection-over-union (Jaccard index) between two sets of boxes.

    Both sets of boxes are expected to be in ``(cx, cy, w, h)`` format with
    ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
    """
    # Usage logging is skipped while scripting or tracing.
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(box_iou_center)
    intersection, total = _box_inter_union_center(boxes1, boxes2)
    return intersection / total


# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
"""
Expand Down