8
8
import torch .testing
9
9
import torchvision .prototype .transforms .functional as F
10
10
from common_utils import cpu_and_gpu
11
- from prototype_common_utils import ArgsKwargs , make_bounding_boxes , make_image , make_images , make_segmentation_masks
11
+ from prototype_common_utils import (
12
+ ArgsKwargs ,
13
+ make_bounding_boxes ,
14
+ make_detection_and_segmentation_masks ,
15
+ make_detection_masks ,
16
+ make_image ,
17
+ make_images ,
18
+ )
12
19
from torch import jit
13
20
from torchvision .prototype import features
14
21
from torchvision .prototype .transforms .functional ._geometry import _center_crop_compute_padding
@@ -55,7 +62,7 @@ def horizontal_flip_bounding_box():
55
62
56
63
@register_kernel_info_from_sample_inputs_fn
57
64
def horizontal_flip_segmentation_mask ():
58
- for mask in make_segmentation_masks ():
65
+ for mask in make_detection_and_segmentation_masks ():
59
66
yield ArgsKwargs (mask )
60
67
61
68
@@ -73,7 +80,7 @@ def vertical_flip_bounding_box():
73
80
74
81
@register_kernel_info_from_sample_inputs_fn
75
82
def vertical_flip_segmentation_mask ():
76
- for mask in make_segmentation_masks ():
83
+ for mask in make_detection_and_segmentation_masks ():
77
84
yield ArgsKwargs (mask )
78
85
79
86
@@ -118,7 +125,7 @@ def resize_bounding_box():
118
125
@register_kernel_info_from_sample_inputs_fn
119
126
def resize_segmentation_mask ():
120
127
for mask , max_size in itertools .product (
121
- make_segmentation_masks (),
128
+ make_detection_and_segmentation_masks (),
122
129
[None , 34 ], # max_size
123
130
):
124
131
height , width = mask .shape [- 2 :]
@@ -173,7 +180,7 @@ def affine_bounding_box():
173
180
@register_kernel_info_from_sample_inputs_fn
174
181
def affine_segmentation_mask ():
175
182
for mask , angle , translate , scale , shear in itertools .product (
176
- make_segmentation_masks (),
183
+ make_detection_and_segmentation_masks (),
177
184
[- 87 , 15 , 90 ], # angle
178
185
[5 , - 5 ], # translate
179
186
[0.77 , 1.27 ], # scale
@@ -226,7 +233,7 @@ def rotate_bounding_box():
226
233
@register_kernel_info_from_sample_inputs_fn
227
234
def rotate_segmentation_mask ():
228
235
for mask , angle , expand , center in itertools .product (
229
- make_segmentation_masks (),
236
+ make_detection_and_segmentation_masks (),
230
237
[- 87 , 15 , 90 ], # angle
231
238
[True , False ], # expand
232
239
[None , [12 , 23 ]], # center
@@ -269,7 +276,7 @@ def crop_bounding_box():
269
276
@register_kernel_info_from_sample_inputs_fn
270
277
def crop_segmentation_mask ():
271
278
for mask , top , left , height , width in itertools .product (
272
- make_segmentation_masks (), [- 8 , 0 , 9 ], [- 8 , 0 , 9 ], [12 , 20 ], [12 , 20 ]
279
+ make_detection_and_segmentation_masks (), [- 8 , 0 , 9 ], [- 8 , 0 , 9 ], [12 , 20 ], [12 , 20 ]
273
280
):
274
281
yield ArgsKwargs (
275
282
mask ,
@@ -307,7 +314,7 @@ def resized_crop_bounding_box():
307
314
@register_kernel_info_from_sample_inputs_fn
308
315
def resized_crop_segmentation_mask ():
309
316
for mask , top , left , height , width , size in itertools .product (
310
- make_segmentation_masks (), [- 8 , 0 , 9 ], [- 8 , 0 , 9 ], [12 , 20 ], [12 , 20 ], [(32 , 32 ), (16 , 18 )]
317
+ make_detection_and_segmentation_masks (), [- 8 , 0 , 9 ], [- 8 , 0 , 9 ], [12 , 20 ], [12 , 20 ], [(32 , 32 ), (16 , 18 )]
311
318
):
312
319
yield ArgsKwargs (mask , top = top , left = left , height = height , width = width , size = size )
313
320
@@ -326,7 +333,7 @@ def pad_image_tensor():
326
333
@register_kernel_info_from_sample_inputs_fn
327
334
def pad_segmentation_mask ():
328
335
for mask , padding , padding_mode in itertools .product (
329
- make_segmentation_masks (),
336
+ make_detection_and_segmentation_masks (),
330
337
[[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]], # padding
331
338
["constant" , "symmetric" , "edge" , "reflect" ], # padding mode,
332
339
):
@@ -374,7 +381,7 @@ def perspective_bounding_box():
374
381
@register_kernel_info_from_sample_inputs_fn
375
382
def perspective_segmentation_mask ():
376
383
for mask , perspective_coeffs in itertools .product (
377
- make_segmentation_masks (extra_dims = ((), (4 ,))),
384
+ make_detection_and_segmentation_masks (extra_dims = ((), (4 ,))),
378
385
[
379
386
[1.2405 , 0.1772 , - 6.9113 , 0.0463 , 1.251 , - 5.235 , 0.00013 , 0.0018 ],
380
387
[0.7366 , - 0.11724 , 1.45775 , - 0.15012 , 0.73406 , 2.6019 , - 0.0072 , - 0.0063 ],
@@ -411,7 +418,7 @@ def elastic_bounding_box():
411
418
412
419
@register_kernel_info_from_sample_inputs_fn
413
420
def elastic_segmentation_mask ():
414
- for mask in make_segmentation_masks (extra_dims = ((), (4 ,))):
421
+ for mask in make_detection_and_segmentation_masks (extra_dims = ((), (4 ,))):
415
422
h , w = mask .shape [- 2 :]
416
423
displacement = torch .rand (1 , h , w , 2 )
417
424
yield ArgsKwargs (
@@ -440,7 +447,7 @@ def center_crop_bounding_box():
440
447
@register_kernel_info_from_sample_inputs_fn
441
448
def center_crop_segmentation_mask ():
442
449
for mask , output_size in itertools .product (
443
- make_segmentation_masks (sizes = ((16 , 16 ), (7 , 33 ), (31 , 9 ))),
450
+ make_detection_and_segmentation_masks (sizes = ((16 , 16 ), (7 , 33 ), (31 , 9 ))),
444
451
[[4 , 3 ], [42 , 70 ], [4 ]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
445
452
):
446
453
yield ArgsKwargs (mask , output_size )
@@ -771,7 +778,8 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
771
778
expected_mask [i , out_y , out_x ] = mask [i , in_y , in_x ]
772
779
return expected_mask .to (mask .device )
773
780
774
- for mask in make_segmentation_masks (extra_dims = ((), (4 ,))):
781
+ # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
782
+ for mask in make_detection_masks (extra_dims = ((), (4 ,))):
775
783
output_mask = F .affine_segmentation_mask (
776
784
mask ,
777
785
angle = angle ,
@@ -1011,7 +1019,8 @@ def _compute_expected_mask(mask, angle_, expand_, center_):
1011
1019
expected_mask [i , out_y , out_x ] = mask [i , in_y , in_x ]
1012
1020
return expected_mask .to (mask .device )
1013
1021
1014
- for mask in make_segmentation_masks (extra_dims = ((), (4 ,))):
1022
+ # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
1023
+ for mask in make_detection_masks (extra_dims = ((), (4 ,))):
1015
1024
output_mask = F .rotate_segmentation_mask (
1016
1025
mask ,
1017
1026
angle = angle ,
@@ -1138,7 +1147,7 @@ def _compute_expected_mask(mask, top_, left_, height_, width_):
1138
1147
1139
1148
return expected
1140
1149
1141
- for mask in make_segmentation_masks ():
1150
+ for mask in make_detection_and_segmentation_masks ():
1142
1151
if mask .device != torch .device (device ):
1143
1152
mask = mask .to (device )
1144
1153
output_mask = F .crop_segmentation_mask (mask , top , left , height , width )
@@ -1358,7 +1367,7 @@ def _compute_expected_mask(mask, padding_, padding_mode_):
1358
1367
1359
1368
return output
1360
1369
1361
- for mask in make_segmentation_masks ():
1370
+ for mask in make_detection_and_segmentation_masks ():
1362
1371
out_mask = F .pad_segmentation_mask (mask , padding , padding_mode = padding_mode )
1363
1372
1364
1373
expected_mask = _compute_expected_mask (mask , padding , padding_mode )
@@ -1487,7 +1496,8 @@ def _compute_expected_mask(mask, pcoeffs_):
1487
1496
1488
1497
pcoeffs = _get_perspective_coeffs (startpoints , endpoints )
1489
1498
1490
- for mask in make_segmentation_masks (extra_dims = ((), (4 ,))):
1499
+ # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
1500
+ for mask in make_detection_masks (extra_dims = ((), (4 ,))):
1491
1501
mask = mask .to (device )
1492
1502
1493
1503
output_mask = F .perspective_segmentation_mask (
@@ -1649,14 +1659,18 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
1649
1659
1650
1660
@pytest .mark .parametrize ("device" , cpu_and_gpu ())
1651
1661
@pytest .mark .parametrize (
1652
- "fn, make_samples" , [(F .elastic_image_tensor , make_images ), (F .elastic_segmentation_mask , make_segmentation_masks )]
1662
+ "fn, make_samples" ,
1663
+ [
1664
+ (F .elastic_image_tensor , make_images ),
1665
+ # FIXME: This test currently only works for "detection" masks. Extend it for "segmentation" masks.
1666
+ (F .elastic_segmentation_mask , make_detection_masks ),
1667
+ ],
1653
1668
)
1654
1669
def test_correctness_elastic_image_or_mask_tensor (device , fn , make_samples ):
1655
1670
in_box = [10 , 15 , 25 , 35 ]
1656
1671
for sample in make_samples (sizes = ((64 , 76 ),), extra_dims = ((), (4 ,))):
1657
1672
c , h , w = sample .shape [- 3 :]
1658
1673
# Setup a dummy image with 4 points
1659
- print (sample .shape )
1660
1674
sample [..., in_box [1 ], in_box [0 ]] = torch .arange (10 , 10 + c )
1661
1675
sample [..., in_box [3 ] - 1 , in_box [0 ]] = torch .arange (20 , 20 + c )
1662
1676
sample [..., in_box [3 ] - 1 , in_box [2 ] - 1 ] = torch .arange (30 , 30 + c )
0 commit comments