Skip to content

Commit 3a2631b

Browse files
Authored by: federicopozzi33 (Federico Pozzi), vfdev-5, and pmeier
feat: add functional center crop on mask (#5961)
* feat: add functional center crop on mask
* test: add correctness center crop with random segmentation mask
* test: improvements
* test: improvements
* Apply suggestions from code review

Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Federico Pozzi <[email protected]>
Co-authored-by: vfdev <[email protected]>
Co-authored-by: Philip Meier <[email protected]>
1 parent 49496c4 commit 3a2631b

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
from torch import jit
1111
from torch.nn.functional import one_hot
1212
from torchvision.prototype import features
13+
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
1314
from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
1415
from torchvision.transforms.functional import _get_perspective_coeffs
1516
from torchvision.transforms.functional_tensor import _max_value as get_max_value
1617

17-
1818
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
1919

2020

@@ -421,6 +421,14 @@ def center_crop_bounding_box():
421421
)
422422

423423

424+
def center_crop_segmentation_mask():
    """Yield ``SampleInput``s pairing random segmentation masks with crop sizes."""
    masks = make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9)))
    # crop sizes < image sizes, crop_sizes > image sizes, single crop size
    crop_sizes = [[4, 3], [42, 70], [4]]
    for mask, output_size in itertools.product(masks, crop_sizes):
        yield SampleInput(mask, output_size)
430+
431+
424432
@pytest.mark.parametrize(
425433
"kernel",
426434
[
@@ -1337,3 +1345,26 @@ def _compute_expected_bbox(bbox, output_size_):
13371345
else:
13381346
expected_bboxes = expected_bboxes[0]
13391347
torch.testing.assert_close(output_boxes, expected_bboxes)
1348+
1349+
1350+
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_segmentation_mask(device, output_size):
    def _compute_expected_segmentation_mask(mask, output_size):
        """Reference center crop: zero-pad when the crop exceeds the image, then slice."""
        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

        _, image_height, image_width = mask.shape
        # Fix: compare width to width and height to height — the original guard
        # transposed the dimensions and only behaved correctly for square-ish inputs.
        if crop_width > image_width or crop_height > image_height:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image_tensor(mask, padding, fill=0)
            # Fix: the spatial size changed after padding; recompute it so the
            # crop offsets below cannot go negative.
            _, image_height, image_width = mask.shape

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)

        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_segmentation_mask(mask, output_size)

    expected = _compute_expected_segmentation_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
resize_image_pil,
4747
resize_segmentation_mask,
4848
center_crop_bounding_box,
49+
center_crop_segmentation_mask,
4950
center_crop_image_tensor,
5051
center_crop_image_pil,
5152
resized_crop_bounding_box,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,10 @@ def center_crop_bounding_box(
630630
return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
631631

632632

633+
def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    """Center-crop ``segmentation_mask`` to ``output_size``.

    Segmentation masks share the tensor layout of images, so this delegates
    directly to the image-tensor kernel.
    """
    return center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
635+
636+
633637
def resized_crop_image_tensor(
634638
img: torch.Tensor,
635639
top: int,

0 commit comments

Comments (0)