 import enum
 import inspect
+from importlib.machinery import SourceFileLoader
+from pathlib import Path

 import numpy as np
 import PIL.Image
 import pytest

 import torch
-from prototype_common_utils import ArgsKwargs, assert_equal, make_images
+from prototype_common_utils import (
+    ArgsKwargs,
+    assert_equal,
+    make_bounding_box,
+    make_detection_mask,
+    make_image,
+    make_images,
+    make_label,
+)
 from torchvision import transforms as legacy_transforms
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features, transforms as prototype_transforms
+from torchvision.prototype.transforms.functional import to_image_pil  # used by make_datapoints below
|
@@ -840,3 +850,80 @@ def test_aa(self, inpt, interpolation):
         output = t(inpt)

         assert_equal(expected_output, output)
+
+
+# Import the reference detection transforms from
+# torchvision/references/detection/transforms.py for the consistency checks
+# below. The references folder is not an importable package, so the module is
+# loaded directly from its file path.
+ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
+det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+
+
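+# The checks below run each reference detection transform and its prototype
+# counterpart on the same sample under the same RNG seed and assert that the
+# outputs match.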
+class TestRefDetTransforms:
+    def make_datapoints(self, with_mask=True):
+        size = (600, 800)
+        num_objects = 22
+
+        # Yield the same kind of sample three times, with the image given as a
+        # PIL image, a plain uint8 tensor, and a features.Image, to check that
+        # both transform implementations handle all supported input types.
+        pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB))
+        target = {
+            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+            "labels": make_label(extra_dims=(num_objects,), categories=80),
+        }
+        if with_mask:
+            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
+
+        yield (pil_image, target)
+
+        tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
+        target = {
+            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+            "labels": make_label(extra_dims=(num_objects,), categories=80),
+        }
+        if with_mask:
+            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
+
+        yield (tensor_image, target)
+
+        feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
+        target = {
+            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
+            "labels": make_label(extra_dims=(num_objects,), categories=80),
+        }
+        if with_mask:
+            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
+
+        yield (feature_image, target)
+
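+    # Each (reference, prototype) pair is instantiated with identical parameters,
+    # so under the same seed both are expected to produce equal outputs.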
+    @pytest.mark.parametrize(
+        "t_ref, t, data_kwargs",
+        [
+            (det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}),
+            (det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}),
+            (det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}),
+            (det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}),
+            (
+                det_transforms.FixedSizeCrop((1024, 1024), fill=0),
+                prototype_transforms.FixedSizeCrop((1024, 1024), fill=0),
+                {},
+            ),
+            (
+                det_transforms.RandomShortestSize(
+                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
+                ),
+                prototype_transforms.RandomShortestSize(
+                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
+                ),
+                {},
+            ),
+        ],
+    )
+    def test_transform(self, t_ref, t, data_kwargs):
+        for dp in self.make_datapoints(**data_kwargs):
+            # Run the prototype transform first: the reference transform updates
+            # the target dict in place, so applying it first would corrupt the
+            # input for the prototype transform.
+            torch.manual_seed(12)
+            output = t(dp)
+
+            # Re-seed so both transforms draw the same random parameters. The
+            # reference transform takes the image and target as separate
+            # arguments, hence the unpacking.
+            torch.manual_seed(12)
+            expected_output = t_ref(*dp)
+
+            assert_equal(expected_output, output)