Fix prototype transforms for (*, H, W) segmentation masks #6574
@@ -178,22 +178,58 @@ def make_one_hot_labels(
        yield make_one_hot_label(extra_dims_)


def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
def make_detection_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
The naming is now quite weird. We are calling a `make_detection_mask` function and get a `features.SegmentationMask` back. Now that we have somewhat settled on the terms segmentation and detection masks, can we maybe rename the class to `features.Mask` since it represents both types?
Sounds good to me!
On second thought, I remember @YosuaMichael telling us that Depth Perception also has a kind of valid mask (see #6495).
@YosuaMichael could you clarify whether they operate exactly like segmentation/detection masks and whether we could just reuse this object?
To give more context:
- a segmentation mask is of shape `(*, H, W)` where each value corresponds to a class index
- a detection mask is of shape `(*, N, H, W)` where `N` denotes the number of objects and the values are otherwise "boolean"

In both cases the `*` in the shape denotes arbitrary batch dimensions.
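To make the two layouts concrete, here is a minimal sketch that builds one tensor of each kind with plain `torch` calls. The sizes and the variable names (`num_classes`, `num_objects`, and so on) are assumptions chosen for illustration and are not taken from the PR.

```python
import torch

# Illustrative sizes only; none of these values come from the PR.
H, W = 4, 6
num_classes = 3
num_objects = 2

# Segmentation mask: shape (*, H, W), every value is a class index.
segmentation_mask = torch.randint(0, num_classes, (H, W), dtype=torch.uint8)

# Detection mask: shape (*, N, H, W), one 0/1 ("boolean") mask per object.
detection_mask = torch.randint(0, 2, (num_objects, H, W), dtype=torch.uint8)

# The leading * stands for arbitrary batch dimensions, here a batch of 5.
batched_segmentation_mask = torch.randint(0, num_classes, (5, H, W), dtype=torch.uint8)

print(segmentation_mask.shape, detection_mask.shape, batched_segmentation_mask.shape)
```

Note that a batched segmentation mask and an unbatched detection mask both have three dimensions, so the number of dimensions alone does not tell the two kinds apart.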
@datumbox I think the `valid_mask` in depth perception will behave like a `detection_mask` with 2 objects (it is boolean, but can be implemented with 0/1 uint as well). So overall I think we can reuse it.
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
for mask in make_detection_masks(extra_dims=((), (4,))):
This is due to the `assert mask.ndim == 3`. This was introduced in #6546 to make sure only detection masks are computed here. @vfdev-5 Could you fix the expected mask computation so both types of masks are supported? If that is done, we should use

    - # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
    - for mask in make_detection_masks(extra_dims=((), (4,))):
    + for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))):

here.
This comment also applies to other occurrences of the same FIXME comment.
    angle: float,
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    return affine_image_tensor(
        mask,
    if segmentation_mask.ndim < 3:
I would love to do that in a decorator, but `torch.jit.script` doesn't support it 😢

    import functools

    def handle_segmentation_mask(mask_kernel):
        @functools.wraps(mask_kernel)
        def wrapper(mask, *other_args, **kwargs):
            # Add a temporary leading dimension if the mask has fewer than 3 dimensions.
            if mask.ndim < 3:
                mask = mask.unsqueeze(0)
                needs_squeeze = True
            else:
                needs_squeeze = False

            output = mask_kernel(mask, *other_args, **kwargs)

            # Drop the temporary dimension again.
            if needs_squeeze:
                output = output.squeeze(0)

            return output

        return wrapper

    @handle_segmentation_mask
    def affine_segmentation_mask(mask, ...):
        ...
The story of our lives...
The changes to the functional look good to me.
Test failure is unrelated:
…6574)

Summary:
* add generator functions for segmentation masks
* update functional tests and fix kernels
* fix transforms tests

Reviewed By: jdsgomes
Differential Revision: D39543270
fbshipit-source-id: 6f56ba288277ce0dd98115a794b2002d5341e83f
Most of our `*_segmentation_mask` kernels simply dispatch to the `*_image_tensor` kernel. However, a segmentation mask (from the segmentation task) might have no channel dimension, i.e. shape `(*, H, W)`. Passing this directly into the image kernel will fail for those image kernels that require this dimension to be there. For example
vision/torchvision/prototype/transforms/functional/_geometry.py
Line 566 in cac4e22

This was not detected before because the CI only tested against "detection" masks, i.e. `(*, N, H, W)` where `N` denotes the number of objects.

All kernels work the same, but in case the input has no channel dimension, we `.unsqueeze(0)` before calling the image kernel and `.squeeze(0)` afterwards.
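The unsqueeze/squeeze pattern described above can be sketched as follows. This is only an illustration under stated assumptions: `flip_image_tensor` and `flip_segmentation_mask` are hypothetical stand-ins, not the kernels touched by this PR, and the stand-in image kernel simply rejects inputs with fewer than three dimensions. The handling is written inline rather than as a decorator, in line with the review comment that `torch.jit.script` does not support the decorator approach.

```python
import torch


def flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for an image kernel that needs a channel dimension.
    if image.ndim < 3:
        raise ValueError("expected an input with at least 3 dimensions")
    return image.flip([-1])  # horizontal flip along the width dimension


def flip_segmentation_mask(mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical mask kernel: add a leading dimension for (H, W) inputs.
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = flip_image_tensor(mask)

    # Restore the original number of dimensions.
    if needs_squeeze:
        output = output.squeeze(0)
    return output


mask_hw = torch.arange(6, dtype=torch.uint8).reshape(2, 3)
flipped_hw = flip_segmentation_mask(mask_hw)

mask_nhw = torch.zeros(2, 2, 3, dtype=torch.uint8)
flipped_nhw = flip_segmentation_mask(mask_nhw)

print(flipped_hw.shape, flipped_nhw.shape)
```

In this sketch a `(H, W)` mask comes back as `(H, W)`, and inputs that already have three or more dimensions pass through the image kernel unchanged in shape.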