Skip to content

[proto] Added consistency tests for detection transforms #6566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 15, 2022
91 changes: 89 additions & 2 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
import enum
import inspect
from importlib.machinery import SourceFileLoader
from pathlib import Path

import numpy as np
import PIL.Image
import pytest

import torch
from prototype_common_utils import ArgsKwargs, assert_equal, make_images
from prototype_common_utils import (
ArgsKwargs,
assert_equal,
make_bounding_box,
make_detection_mask,
make_image,
make_images,
make_label,
)
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.functional import to_image_pil, to_pil_image


DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
Expand Down Expand Up @@ -840,3 +850,80 @@ def test_aa(self, inpt, interpolation):
output = t(inpt)

assert_equal(expected_output, output)


# Load the reference detection transforms (references/detection/transforms.py, resolved relative to
# the parent of this file's directory) straight from their source file so the consistency checks in
# this module can compare against them.
# NOTE(review): importlib documents ``load_module()`` as deprecated in favour of ``exec_module()``;
# consider migrating when this loader is next touched.
ref_det_filepath = Path(__file__).parent.parent.joinpath("references", "detection", "transforms.py")
det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()


class TestRefDetTransforms:
    """Consistency checks of the prototype detection transforms against the reference ones."""

    def make_datapoints(self, with_mask=True):
        """Yield ``(image, target)`` samples, one per image container.

        The image is produced, in order, as a PIL image, a plain ``uint8`` tensor and a
        ``features.Image``. Every sample gets its own freshly built ``target`` dict holding
        ``"boxes"`` and ``"labels"``, plus ``"masks"`` when ``with_mask`` is true.
        """
        size = (600, 800)
        num_objects = 22

        def random_uint8_image():
            return torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)

        image_factories = (
            lambda: to_pil_image(make_image(size=size, color_space=features.ColorSpace.RGB)),
            random_uint8_image,
            lambda: features.Image(random_uint8_image()),
        )

        for build_image in image_factories:
            # The image is created first, then the target that accompanies it.
            image = build_image()

            target = {
                "boxes": make_bounding_box(
                    image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float
                ),
                "labels": make_label(size=(num_objects,)),
            }
            if with_mask:
                target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

            yield image, target

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}),
            (det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}),
            (det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}),
            (
                det_transforms.FixedSizeCrop((1024, 1024), fill=0),
                prototype_transforms.FixedSizeCrop((1024, 1024), fill=0),
                {},
            ),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                prototype_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        """Check that ``t`` and ``t_ref`` give equal outputs on every sample under the same seed.

        ``data_kwargs`` is forwarded to ``make_datapoints``.
        """
        seed = 12

        for sample in self.make_datapoints(**data_kwargs):
            # The prototype transform has to run first: the reference transform updates the
            # target in place, so running it earlier would change the prototype's input.
            torch.manual_seed(seed)
            actual = t(sample)

            torch.manual_seed(seed)
            expected = t_ref(*sample)

            assert_equal(expected, actual)