-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Fix prototype transforms for (*, H, W)
segmentation masks
#6574
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -8,7 +8,14 @@ | |||||||
import torch.testing | ||||||||
import torchvision.prototype.transforms.functional as F | ||||||||
from common_utils import cpu_and_gpu | ||||||||
from prototype_common_utils import ArgsKwargs, make_bounding_boxes, make_image, make_images, make_segmentation_masks | ||||||||
from prototype_common_utils import ( | ||||||||
ArgsKwargs, | ||||||||
make_bounding_boxes, | ||||||||
make_detection_and_segmentation_masks, | ||||||||
make_detection_masks, | ||||||||
make_image, | ||||||||
make_images, | ||||||||
) | ||||||||
from torch import jit | ||||||||
from torchvision.prototype import features | ||||||||
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding | ||||||||
|
@@ -55,7 +62,7 @@ def horizontal_flip_bounding_box(): | |||||||
|
||||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def horizontal_flip_segmentation_mask(): | ||||||||
for mask in make_segmentation_masks(): | ||||||||
for mask in make_detection_and_segmentation_masks(): | ||||||||
yield ArgsKwargs(mask) | ||||||||
|
||||||||
|
||||||||
|
@@ -73,7 +80,7 @@ def vertical_flip_bounding_box(): | |||||||
|
||||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def vertical_flip_segmentation_mask(): | ||||||||
for mask in make_segmentation_masks(): | ||||||||
for mask in make_detection_and_segmentation_masks(): | ||||||||
yield ArgsKwargs(mask) | ||||||||
|
||||||||
|
||||||||
|
@@ -118,7 +125,7 @@ def resize_bounding_box(): | |||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def resize_segmentation_mask(): | ||||||||
for mask, max_size in itertools.product( | ||||||||
make_segmentation_masks(), | ||||||||
make_detection_and_segmentation_masks(), | ||||||||
[None, 34], # max_size | ||||||||
): | ||||||||
height, width = mask.shape[-2:] | ||||||||
|
@@ -173,7 +180,7 @@ def affine_bounding_box(): | |||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def affine_segmentation_mask(): | ||||||||
for mask, angle, translate, scale, shear in itertools.product( | ||||||||
make_segmentation_masks(), | ||||||||
make_detection_and_segmentation_masks(), | ||||||||
[-87, 15, 90], # angle | ||||||||
[5, -5], # translate | ||||||||
[0.77, 1.27], # scale | ||||||||
|
@@ -226,7 +233,7 @@ def rotate_bounding_box(): | |||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def rotate_segmentation_mask(): | ||||||||
for mask, angle, expand, center in itertools.product( | ||||||||
make_segmentation_masks(), | ||||||||
make_detection_and_segmentation_masks(), | ||||||||
[-87, 15, 90], # angle | ||||||||
[True, False], # expand | ||||||||
[None, [12, 23]], # center | ||||||||
|
@@ -269,7 +276,7 @@ def crop_bounding_box(): | |||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def crop_segmentation_mask(): | ||||||||
for mask, top, left, height, width in itertools.product( | ||||||||
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20] | ||||||||
make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20] | ||||||||
): | ||||||||
yield ArgsKwargs( | ||||||||
mask, | ||||||||
|
@@ -307,7 +314,7 @@ def resized_crop_bounding_box(): | |||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def resized_crop_segmentation_mask(): | ||||||||
for mask, top, left, height, width, size in itertools.product( | ||||||||
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)] | ||||||||
make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)] | ||||||||
): | ||||||||
yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size) | ||||||||
|
||||||||
|
@@ -326,7 +333,7 @@ def pad_image_tensor(): | |||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def pad_segmentation_mask(): | ||||||||
for mask, padding, padding_mode in itertools.product( | ||||||||
make_segmentation_masks(), | ||||||||
make_detection_and_segmentation_masks(), | ||||||||
[[1], [1, 1], [1, 1, 2, 2]], # padding | ||||||||
["constant", "symmetric", "edge", "reflect"], # padding mode, | ||||||||
): | ||||||||
|
@@ -374,7 +381,7 @@ def perspective_bounding_box(): | |||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def perspective_segmentation_mask(): | ||||||||
for mask, perspective_coeffs in itertools.product( | ||||||||
make_segmentation_masks(extra_dims=((), (4,))), | ||||||||
make_detection_and_segmentation_masks(extra_dims=((), (4,))), | ||||||||
[ | ||||||||
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], | ||||||||
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], | ||||||||
|
@@ -411,7 +418,7 @@ def elastic_bounding_box(): | |||||||
|
||||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def elastic_segmentation_mask(): | ||||||||
for mask in make_segmentation_masks(extra_dims=((), (4,))): | ||||||||
for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))): | ||||||||
h, w = mask.shape[-2:] | ||||||||
displacement = torch.rand(1, h, w, 2) | ||||||||
yield ArgsKwargs( | ||||||||
|
@@ -440,7 +447,7 @@ def center_crop_bounding_box(): | |||||||
@register_kernel_info_from_sample_inputs_fn | ||||||||
def center_crop_segmentation_mask(): | ||||||||
for mask, output_size in itertools.product( | ||||||||
make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))), | ||||||||
make_detection_and_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))), | ||||||||
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size | ||||||||
): | ||||||||
yield ArgsKwargs(mask, output_size) | ||||||||
|
@@ -771,7 +778,8 @@ def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_): | |||||||
expected_mask[i, out_y, out_x] = mask[i, in_y, in_x] | ||||||||
return expected_mask.to(mask.device) | ||||||||
|
||||||||
for mask in make_segmentation_masks(extra_dims=((), (4,))): | ||||||||
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks. | ||||||||
for mask in make_detection_masks(extra_dims=((), (4,))): | ||||||||
Comment on lines
+781
to
+782
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is due to the
Suggested change
here. This comment also applies to other occurrences of the same |
||||||||
output_mask = F.affine_segmentation_mask( | ||||||||
mask, | ||||||||
angle=angle, | ||||||||
|
@@ -1011,7 +1019,8 @@ def _compute_expected_mask(mask, angle_, expand_, center_): | |||||||
expected_mask[i, out_y, out_x] = mask[i, in_y, in_x] | ||||||||
return expected_mask.to(mask.device) | ||||||||
|
||||||||
for mask in make_segmentation_masks(extra_dims=((), (4,))): | ||||||||
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks. | ||||||||
for mask in make_detection_masks(extra_dims=((), (4,))): | ||||||||
output_mask = F.rotate_segmentation_mask( | ||||||||
mask, | ||||||||
angle=angle, | ||||||||
|
@@ -1138,7 +1147,7 @@ def _compute_expected_mask(mask, top_, left_, height_, width_): | |||||||
|
||||||||
return expected | ||||||||
|
||||||||
for mask in make_segmentation_masks(): | ||||||||
for mask in make_detection_and_segmentation_masks(): | ||||||||
if mask.device != torch.device(device): | ||||||||
mask = mask.to(device) | ||||||||
output_mask = F.crop_segmentation_mask(mask, top, left, height, width) | ||||||||
|
@@ -1358,7 +1367,7 @@ def _compute_expected_mask(mask, padding_, padding_mode_): | |||||||
|
||||||||
return output | ||||||||
|
||||||||
for mask in make_segmentation_masks(): | ||||||||
for mask in make_detection_and_segmentation_masks(): | ||||||||
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode) | ||||||||
|
||||||||
expected_mask = _compute_expected_mask(mask, padding, padding_mode) | ||||||||
|
@@ -1487,7 +1496,8 @@ def _compute_expected_mask(mask, pcoeffs_): | |||||||
|
||||||||
pcoeffs = _get_perspective_coeffs(startpoints, endpoints) | ||||||||
|
||||||||
for mask in make_segmentation_masks(extra_dims=((), (4,))): | ||||||||
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks. | ||||||||
for mask in make_detection_masks(extra_dims=((), (4,))): | ||||||||
mask = mask.to(device) | ||||||||
|
||||||||
output_mask = F.perspective_segmentation_mask( | ||||||||
|
@@ -1649,14 +1659,18 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s | |||||||
|
||||||||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||||||||
@pytest.mark.parametrize( | ||||||||
"fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)] | ||||||||
"fn, make_samples", | ||||||||
[ | ||||||||
(F.elastic_image_tensor, make_images), | ||||||||
# FIXME: This test currently only works for "detection" masks. Extend it for "segmentation" masks. | ||||||||
(F.elastic_segmentation_mask, make_detection_masks), | ||||||||
], | ||||||||
) | ||||||||
def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples): | ||||||||
in_box = [10, 15, 25, 35] | ||||||||
for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))): | ||||||||
c, h, w = sample.shape[-3:] | ||||||||
# Setup a dummy image with 4 points | ||||||||
print(sample.shape) | ||||||||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
sample[..., in_box[1], in_box[0]] = torch.arange(10, 10 + c) | ||||||||
sample[..., in_box[3] - 1, in_box[0]] = torch.arange(20, 20 + c) | ||||||||
sample[..., in_box[3] - 1, in_box[2] - 1] = torch.arange(30, 30 + c) | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The naming is now quite weird. We are calling a
`make_detection_mask`
function and get a `features.SegmentationMask`
back. Now that we have somewhat settled on the terms segmentation and detection masks, can we maybe rename the class to `features.Mask`
since it represents both types?

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good to me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On a second thought, I remember @YosuaMichael telling us that on Depth Perception they also have a sort of Valid Masks (see #6495).
@YosuaMichael could you clarify if they operate exactly like Segmentation/Detection masks and if we could just reuse this object?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To give more context:
`(*, H, W)`,
where each value corresponds to a class index; `(*, N, H, W)`,
where `N`
denotes the number of objects and the values are otherwise "boolean". In both cases the
`*`
in the shape denotes arbitrary batch dimensions.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@datumbox I think the valid_mask on depth perception will behave like a detection_mask with 2 objects (it is boolean, but can be implemented with 0/1 uint as well). So overall I think we can reuse it.