Skip to content

Commit 544a407

Browse files
vfdev-5pmeier
andauthored
[proto] Added consistency tests for detection transforms (#6566)
* [proto] Added consistency tests for detection transforms * Updated tests according to the review * More updates Co-authored-by: Philip Meier <[email protected]>
1 parent e0e9538 commit 544a407

File tree

1 file changed

+88
-1
lines changed

1 file changed

+88
-1
lines changed

test/test_prototype_transforms_consistency.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
import enum
22
import inspect
3+
from importlib.machinery import SourceFileLoader
4+
from pathlib import Path
35

46
import numpy as np
57
import PIL.Image
68
import pytest
79

810
import torch
9-
from prototype_common_utils import ArgsKwargs, assert_equal, make_images
11+
from prototype_common_utils import (
12+
ArgsKwargs,
13+
assert_equal,
14+
make_bounding_box,
15+
make_detection_mask,
16+
make_image,
17+
make_images,
18+
make_label,
19+
)
1020
from torchvision import transforms as legacy_transforms
1121
from torchvision._utils import sequence_to_str
1222
from torchvision.prototype import features, transforms as prototype_transforms
@@ -840,3 +850,80 @@ def test_aa(self, inpt, interpolation):
840850
output = t(inpt)
841851

842852
assert_equal(expected_output, output)
853+
854+
855+
# Import reference detection transforms here for consistency checks
856+
# torchvision/references/detection/transforms.py
857+
ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
858+
det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
859+
860+
861+
class TestRefDetTransforms:
862+
def make_datapoints(self, with_mask=True):
863+
size = (600, 800)
864+
num_objects = 22
865+
866+
pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB))
867+
target = {
868+
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
869+
"labels": make_label(extra_dims=(num_objects,), categories=80),
870+
}
871+
if with_mask:
872+
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
873+
874+
yield (pil_image, target)
875+
876+
tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
877+
target = {
878+
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
879+
"labels": make_label(extra_dims=(num_objects,), categories=80),
880+
}
881+
if with_mask:
882+
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
883+
884+
yield (tensor_image, target)
885+
886+
feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
887+
target = {
888+
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
889+
"labels": make_label(extra_dims=(num_objects,), categories=80),
890+
}
891+
if with_mask:
892+
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
893+
894+
yield (feature_image, target)
895+
896+
@pytest.mark.parametrize(
897+
"t_ref, t, data_kwargs",
898+
[
899+
(det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}),
900+
(det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}),
901+
(det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}),
902+
(det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}),
903+
(
904+
det_transforms.FixedSizeCrop((1024, 1024), fill=0),
905+
prototype_transforms.FixedSizeCrop((1024, 1024), fill=0),
906+
{},
907+
),
908+
(
909+
det_transforms.RandomShortestSize(
910+
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
911+
),
912+
prototype_transforms.RandomShortestSize(
913+
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
914+
),
915+
{},
916+
),
917+
],
918+
)
919+
def test_transform(self, t_ref, t, data_kwargs):
920+
for dp in self.make_datapoints(**data_kwargs):
921+
922+
# We should use prototype transform first as reference transform performs inplace target update
923+
torch.manual_seed(12)
924+
output = t(dp)
925+
926+
torch.manual_seed(12)
927+
expected_output = t_ref(*dp)
928+
929+
assert_equal(expected_output, output)

0 commit comments

Comments
 (0)