Skip to content

Commit 93fce27

Browse files
prabhat00155facebook-github-bot
authored andcommitted
[fbsync] add prototype transforms that use the prototype dispatchers (#5418)
Summary: * add prototype transforms that use the prototype dispatchers Reviewed By: jdsgomes Differential Revision: D34475309 fbshipit-source-id: 8366d044b8e118c7c360bfd6d828ef87c3055ced
1 parent 023ca7f commit 93fce27

21 files changed

+1310
-114
lines changed

test/test_prototype_builtin_datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
9999
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
100100
)
101101

102-
@pytest.mark.xfail
103102
@parametrize_dataset_mocks(DATASET_MOCKS)
104103
def test_transformable(self, test_home, dataset_mock, config):
105104
dataset_mock.prepare(test_home, config)

test/test_prototype_transforms.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import itertools
2+
3+
import PIL.Image
4+
import pytest
5+
import torch
6+
from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels
7+
from torchvision.prototype import transforms, features
8+
from torchvision.transforms.functional import to_pil_image
9+
10+
11+
def make_vanilla_tensor_images(*args, **kwargs):
12+
for image in make_images(*args, **kwargs):
13+
if image.ndim > 3:
14+
continue
15+
yield image.data
16+
17+
18+
def make_pil_images(*args, **kwargs):
19+
for image in make_vanilla_tensor_images(*args, **kwargs):
20+
yield to_pil_image(image)
21+
22+
23+
def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
24+
for bounding_box in make_bounding_boxes(*args, **kwargs):
25+
yield bounding_box.data
26+
27+
28+
INPUT_CREATIONS_FNS = {
29+
features.Image: make_images,
30+
features.BoundingBox: make_bounding_boxes,
31+
features.OneHotLabel: make_one_hot_labels,
32+
torch.Tensor: make_vanilla_tensor_images,
33+
PIL.Image.Image: make_pil_images,
34+
}
35+
36+
37+
def parametrize(transforms_with_inputs):
38+
return pytest.mark.parametrize(
39+
("transform", "input"),
40+
[
41+
pytest.param(
42+
transform,
43+
input,
44+
id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}",
45+
)
46+
for transform, inputs in transforms_with_inputs
47+
for idx, input in enumerate(inputs)
48+
],
49+
)
50+
51+
52+
def parametrize_from_transforms(*transforms):
53+
transforms_with_inputs = []
54+
for transform in transforms:
55+
dispatcher = transform._DISPATCHER
56+
if dispatcher is None:
57+
continue
58+
59+
for type_ in dispatcher._kernels:
60+
try:
61+
inputs = INPUT_CREATIONS_FNS[type_]()
62+
except KeyError:
63+
continue
64+
65+
transforms_with_inputs.append((transform, inputs))
66+
67+
return parametrize(transforms_with_inputs)
68+
69+
70+
class TestSmoke:
71+
@parametrize_from_transforms(
72+
transforms.RandomErasing(),
73+
transforms.HorizontalFlip(),
74+
transforms.Resize([16, 16]),
75+
transforms.CenterCrop([16, 16]),
76+
transforms.ConvertImageDtype(),
77+
)
78+
def test_common(self, transform, input):
79+
transform(input)
80+
81+
@parametrize(
82+
[
83+
(
84+
transform,
85+
[
86+
dict(
87+
image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float),
88+
one_hot_label=features.OneHotLabel.new_like(
89+
one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float
90+
),
91+
)
92+
for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels())
93+
],
94+
)
95+
for transform in [
96+
transforms.RandomMixup(alpha=1.0),
97+
transforms.RandomCutmix(alpha=1.0),
98+
]
99+
]
100+
)
101+
def test_mixup_cutmix(self, transform, input):
102+
transform(input)
103+
104+
@parametrize(
105+
[
106+
(
107+
transform,
108+
itertools.chain.from_iterable(
109+
fn(dtypes=[torch.uint8], extra_dims=[(4,)])
110+
for fn in [
111+
make_images,
112+
make_vanilla_tensor_images,
113+
make_pil_images,
114+
]
115+
),
116+
)
117+
for transform in (
118+
transforms.RandAugment(),
119+
transforms.TrivialAugmentWide(),
120+
transforms.AutoAugment(),
121+
)
122+
]
123+
)
124+
def test_auto_augment(self, transform, input):
125+
transform(input)
126+
127+
@parametrize(
128+
[
129+
(
130+
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
131+
itertools.chain.from_iterable(
132+
fn(color_spaces=["rgb"], dtypes=[torch.float32])
133+
for fn in [
134+
make_images,
135+
make_vanilla_tensor_images,
136+
]
137+
),
138+
),
139+
]
140+
)
141+
def test_normalize(self, transform, input):
142+
transform(input)
143+
144+
@parametrize(
145+
[
146+
(
147+
transforms.ConvertColorSpace("grayscale"),
148+
itertools.chain(
149+
make_images(),
150+
make_vanilla_tensor_images(color_spaces=["rgb"]),
151+
make_pil_images(color_spaces=["rgb"]),
152+
),
153+
)
154+
]
155+
)
156+
def test_convert_bounding_color_space(self, transform, input):
157+
transform(input)
158+
159+
@parametrize(
160+
[
161+
(
162+
transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"),
163+
itertools.chain(
164+
make_bounding_boxes(),
165+
make_vanilla_tensor_bounding_boxes(formats=["xywh"]),
166+
),
167+
)
168+
]
169+
)
170+
def test_convert_bounding_box_format(self, transform, input):
171+
transform(input)
172+
173+
@parametrize(
174+
[
175+
(
176+
transforms.RandomResizedCrop([16, 16]),
177+
itertools.chain(
178+
make_images(extra_dims=[(4,)]),
179+
make_vanilla_tensor_images(),
180+
make_pil_images(),
181+
),
182+
)
183+
]
184+
)
185+
def test_random_resized_crop(self, transform, input):
186+
transform(input)

test/test_prototype_transforms_kernels.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch.testing
66
import torchvision.prototype.transforms.kernels as K
77
from torch import jit
8+
from torch.nn.functional import one_hot
89
from torchvision.prototype import features
910

1011
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
@@ -39,10 +40,10 @@ def make_images(
3940
extra_dims=((4,), (2, 3)),
4041
):
4142
for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
42-
yield make_image(size, color_space=color_space)
43+
yield make_image(size, color_space=color_space, dtype=dtype)
4344

44-
for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims):
45-
yield make_image(color_space=color_space, extra_dims=extra_dims_)
45+
for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
46+
yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
4647

4748

4849
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
@@ -106,6 +107,27 @@ def make_bounding_boxes(
106107
yield make_bounding_box(format=format, extra_dims=extra_dims_)
107108

108109

110+
def make_label(size=(), *, categories=("category0", "category1")):
111+
return features.Label(torch.randint(0, len(categories) if categories else 10, size), categories=categories)
112+
113+
114+
def make_one_hot_label(*args, **kwargs):
115+
label = make_label(*args, **kwargs)
116+
return features.OneHotLabel(one_hot(label, num_classes=len(label.categories)), categories=label.categories)
117+
118+
119+
def make_one_hot_labels(
120+
*,
121+
num_categories=(1, 2, 10),
122+
extra_dims=((4,), (2, 3)),
123+
):
124+
for num_categories_ in num_categories:
125+
yield make_one_hot_label(categories=[f"category{idx}" for idx in range(num_categories_)])
126+
127+
for extra_dims_ in extra_dims:
128+
yield make_one_hot_label(extra_dims_)
129+
130+
109131
class SampleInput:
110132
def __init__(self, *args, **kwargs):
111133
self.args = args
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1+
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
12
from . import kernels # usort: skip
23
from . import functional # usort: skip
3-
from .kernels import InterpolationMode # usort: skip
4+
from ._transform import Transform # usort: skip
45

6+
from ._augment import RandomErasing, RandomMixup, RandomCutmix
7+
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
8+
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
9+
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
10+
from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace
11+
from ._misc import Identity, Normalize, ToDtype, Lambda
512
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
13+
from ._type_conversion import DecodeImage, LabelToOneHot

0 commit comments

Comments
 (0)