
Commit 165f35b

NicolasHug and pmeier authored and committed
[fbsync] [proto] Ported RandomIoUCrop from detection refs (#6401)
Summary:

* [proto] Ported RandomIoUCrop from detection refs
* Scope acceptable data types
* Added get_params test
* Added test__transform_empty_params
* Added support for OneHotLabel and tests
* Added tests for mask
* Updated error message
* Apply suggestions from code review
* Added support for OHE masks and tests
* Ignored mypy error
* Fixed forward call on sample
* Added a todo

Reviewed By: datumbox

Differential Revision: D39013670

fbshipit-source-id: 47101bf538c88a5905ab6e5f4c5984271b8c57cd

Co-authored-by: Philip Meier <[email protected]>
1 parent ad1ff33 commit 165f35b

File tree: 6 files changed (+253, -2 lines)


test/test_prototype_transforms.py

Lines changed: 120 additions & 1 deletion
@@ -6,7 +6,7 @@
 
 import pytest
 import torch
-from common_utils import assert_equal
+from common_utils import assert_equal, cpu_and_gpu
 from test_prototype_transforms_functional import (
     make_bounding_box,
     make_bounding_boxes,
@@ -15,6 +15,7 @@
     make_one_hot_labels,
     make_segmentation_mask,
 )
+from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 
@@ -1127,6 +1128,124 @@ def test_ctor(self, trfms):
         assert isinstance(output, torch.Tensor)
 
 
+class TestRandomIoUCrop:
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
+    def test__get_params(self, device, options, mocker):
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+        bboxes = features.BoundingBox(
+            torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
+            format="XYXY",
+            image_size=image.image_size,
+            device=device,
+        )
+        sample = [image, bboxes]
+
+        transform = transforms.RandomIoUCrop(sampler_options=options)
+
+        n_samples = 5
+        for _ in range(n_samples):
+
+            params = transform._get_params(sample)
+
+            if options == [2.0]:
+                assert len(params) == 0
+                return
+
+            assert len(params["is_within_crop_area"]) > 0
+            assert params["is_within_crop_area"].dtype == torch.bool
+
+            orig_h = image.image_size[0]
+            orig_w = image.image_size[1]
+            assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
+            assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)
+
+            left, top = params["left"], params["top"]
+            new_h, new_w = params["height"], params["width"]
+            ious = box_iou(
+                bboxes,
+                torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device),
+            )
+            assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}"
+
+    def test__transform_empty_params(self, mocker):
+        transform = transforms.RandomIoUCrop(sampler_options=[2.0])
+        image = features.Image(torch.rand(1, 3, 4, 4))
+        bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4))
+        label = features.Label(torch.tensor([1]))
+        sample = [image, bboxes, label]
+        # Let's mock transform._get_params to control the output:
+        transform._get_params = mocker.MagicMock(return_value={})
+        output = transform(sample)
+        torch.testing.assert_close(output, sample)
+
+    def test_forward_assertion(self):
+        transform = transforms.RandomIoUCrop()
+        with pytest.raises(
+            TypeError,
+            match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
+        ):
+            transform(torch.tensor(0))
+
+    def test__transform(self, mocker):
+        transform = transforms.RandomIoUCrop()
+
+        image = features.Image(torch.rand(3, 32, 24))
+        bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
+        label = features.Label(torch.randint(0, 10, size=(6,)))
+        ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
+        masks = make_segmentation_mask((32, 24))
+        ohe_masks = features.SegmentationMask(torch.randint(0, 2, size=(6, 32, 24)))
+        sample = [image, bboxes, label, ohe_label, masks, ohe_masks]
+
+        fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
+        is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
+
+        params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
+        transform._get_params = mocker.MagicMock(return_value=params)
+        output = transform(sample)
+
+        assert fn.call_count == 4
+
+        expected_calls = [
+            mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(
+                ohe_masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+            ),
+        ]
+
+        fn.assert_has_calls(expected_calls)
+
+        expected_within_targets = sum(is_within_crop_area)
+
+        # check number of bboxes vs number of labels:
+        output_bboxes = output[1]
+        assert isinstance(output_bboxes, features.BoundingBox)
+        assert len(output_bboxes) == expected_within_targets
+
+        # check labels
+        output_label = output[2]
+        assert isinstance(output_label, features.Label)
+        assert len(output_label) == expected_within_targets
+        torch.testing.assert_close(output_label, label[is_within_crop_area])
+
+        output_ohe_label = output[3]
+        assert isinstance(output_ohe_label, features.OneHotLabel)
+        torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])
+
+        output_masks = output[4]
+        assert isinstance(output_masks, features.SegmentationMask)
+        assert output_masks.shape[:-2] == masks.shape[:-2]
+
+        output_ohe_masks = output[5]
+        assert isinstance(output_ohe_masks, features.SegmentationMask)
+        assert len(output_ohe_masks) == expected_within_targets
+
+
 class TestScaleJitter:
     def test__get_params(self, mocker):
         image_size = (24, 32)

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
     RandomAffine,
     RandomCrop,
     RandomHorizontalFlip,
+    RandomIoUCrop,
    RandomPerspective,
    RandomResizedCrop,
    RandomRotation,

torchvision/prototype/transforms/_geometry.py

Lines changed: 113 additions & 1 deletion
@@ -5,15 +5,17 @@
 
 import PIL.Image
 import torch
+from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
 from torchvision.transforms.functional_tensor import _parse_pad_padding
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size
+
 from typing_extensions import Literal
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
+from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bounding_box, query_image
 
 
 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -620,6 +622,116 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         )
 
 
+class RandomIoUCrop(Transform):
+    def __init__(
+        self,
+        min_scale: float = 0.3,
+        max_scale: float = 1.0,
+        min_aspect_ratio: float = 0.5,
+        max_aspect_ratio: float = 2.0,
+        sampler_options: Optional[List[float]] = None,
+        trials: int = 40,
+    ):
+        super().__init__()
+        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+        self.min_aspect_ratio = min_aspect_ratio
+        self.max_aspect_ratio = max_aspect_ratio
+        if sampler_options is None:
+            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
+        self.options = sampler_options
+        self.trials = trials
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+
+        image = query_image(sample)
+        _, orig_h, orig_w = get_image_dimensions(image)
+        bboxes = query_bounding_box(sample)
+
+        while True:
+            # sample an option
+            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
+            min_jaccard_overlap = self.options[idx]
+            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
+                return dict()
+
+            for _ in range(self.trials):
+                # check the aspect ratio limitations
+                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
+                new_w = int(orig_w * r[0])
+                new_h = int(orig_h * r[1])
+                aspect_ratio = new_w / new_h
+                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
+                    continue
+
+                # check for 0 area crops
+                r = torch.rand(2)
+                left = int((orig_w - new_w) * r[0])
+                top = int((orig_h - new_h) * r[1])
+                right = left + new_w
+                bottom = top + new_h
+                if left == right or top == bottom:
+                    continue
+
+                # check for any valid boxes with centers within the crop area
+                xyxy_bboxes = F.convert_bounding_box_format(
+                    bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
+                )
+                cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
+                cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
+                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
+                if not is_within_crop_area.any():
+                    continue
+
+                # check at least 1 box with jaccard limitations
+                xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
+                ious = box_iou(
+                    xyxy_bboxes,
+                    torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
+                )
+                if ious.max() < min_jaccard_overlap:
+                    continue
+
+                return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        if len(params) < 1:
+            return inpt
+
+        is_within_crop_area = params["is_within_crop_area"]
+
+        if isinstance(inpt, (features.Label, features.OneHotLabel)):
+            return inpt.new_like(inpt, inpt[is_within_crop_area])  # type: ignore[arg-type]
+
+        output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
+
+        if isinstance(output, features.BoundingBox):
+            bboxes = output[is_within_crop_area]
+            bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
+            output = features.BoundingBox.new_like(output, bboxes)
+        elif isinstance(output, features.SegmentationMask) and output.shape[-3] > 1:
+            # apply is_within_crop_area if mask is one-hot encoded
+            masks = output[is_within_crop_area]
+            output = features.SegmentationMask.new_like(output, masks)
+
+        return output
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        # TODO: Allow image to be a torch.Tensor
+        if not (
+            has_all(sample, features.BoundingBox)
+            and has_any(sample, PIL.Image.Image, features.Image)
+            and has_any(sample, features.Label, features.OneHotLabel)
+        ):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
+                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
+            )
+        return super().forward(sample)
+
+
 class ScaleJitter(Transform):
     def __init__(
         self,
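For reference, a short sketch of how the sampler added above can be tuned. The keyword names are exactly those of the constructor in the diff; the values are purely illustrative.

from torchvision.prototype import transforms

crop = transforms.RandomIoUCrop(
    min_scale=0.5,          # crop edges sampled between 50% and 100% of the original size
    max_scale=1.0,
    min_aspect_ratio=0.75,  # crops outside this aspect-ratio band are rejected
    max_aspect_ratio=1.33,
    sampler_options=[0.3, 0.5, 0.7, 1.0],  # one IoU threshold drawn per call; >= 1.0 leaves the sample unchanged
    trials=40,              # attempts per drawn threshold before drawing a new one
)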

torchvision/prototype/transforms/_utils.py

Lines changed: 9 additions & 0 deletions
@@ -17,6 +17,15 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
     raise TypeError("No image was found in the sample")
 
 
+def query_bounding_box(sample: Any) -> features.BoundingBox:
+    flat_sample, _ = tree_flatten(sample)
+    for i in flat_sample:
+        if isinstance(i, features.BoundingBox):
+            return i
+
+    raise TypeError("No bounding box was found in the sample")
+
+
 def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
     if isinstance(image, features.Image):
         channels = image.num_channels
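A small sketch of how the new helper behaves: it flattens the (possibly nested) sample and returns the first BoundingBox it finds, raising otherwise. The sample structure below is hypothetical, and the import targets a private module, so treat this as illustration only.

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import query_bounding_box

sample = {
    "image": features.Image(torch.rand(3, 8, 8)),
    "targets": [features.BoundingBox(torch.tensor([[0, 0, 4, 4]]), format="XYXY", image_size=(8, 8))],
}
bbox = query_bounding_box(sample)  # the BoundingBox nested under "targets"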

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
 from ._meta import (
+    clamp_bounding_box,
     convert_bounding_box_format,
     convert_color_space_image_tensor,
     convert_color_space_image_pil,

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 9 additions & 0 deletions
@@ -61,6 +61,15 @@ def convert_bounding_box_format(
     return bounding_box
 
 
+def clamp_bounding_box(
+    bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
+) -> torch.Tensor:
+    xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY)
+    xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
+    xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
+    return convert_bounding_box_format(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)
+
+
 def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return image[..., :-1, :, :], image[..., -1:, :, :]
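A quick sketch of what the new functional does, based on the implementation above: coordinates are converted to XYXY, clipped to the image, and converted back. Note that image_size is (height, width); the values below are only an illustration.

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

boxes = torch.tensor([[-5.0, 10.0, 40.0, 70.0]])  # partially outside a 64x32 (HxW) image
clamped = F.clamp_bounding_box(boxes, format=features.BoundingBoxFormat.XYXY, image_size=(64, 32))
# x coordinates are clipped to [0, 32], y coordinates to [0, 64]:
# tensor([[ 0., 10., 32., 64.]])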
