[proto] Ported RandomIoUCrop from detection refs #6401
@@ -5,15 +5,17 @@

import PIL.Image
import torch
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
from torchvision.transforms.functional_tensor import _parse_pad_padding
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size

from typing_extensions import Literal

from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
+from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bounding_box, query_image


class RandomHorizontalFlip(_RandomApplyTransform):
@@ -620,6 +622,116 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        )


class RandomIoUCrop(Transform):
    def __init__(
        self,
        min_scale: float = 0.3,
        max_scale: float = 1.0,
        min_aspect_ratio: float = 0.5,
        max_aspect_ratio: float = 2.0,
        sampler_options: Optional[List[float]] = None,
        trials: int = 40,
    ):
        super().__init__()
        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        if sampler_options is None:
            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        self.options = sampler_options
        self.trials = trials

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        image = query_image(sample)
        _, orig_h, orig_w = get_image_dimensions(image)
        bboxes = query_bounding_box(sample)

        while True:
            # sample an option
            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
            min_jaccard_overlap = self.options[idx]
            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
                return dict()

            for _ in range(self.trials):
                # check the aspect ratio limitations
                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
                new_w = int(orig_w * r[0])
                new_h = int(orig_h * r[1])
                aspect_ratio = new_w / new_h
                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
                    continue

                # check for 0 area crops
                r = torch.rand(2)
                left = int((orig_w - new_w) * r[0])
                top = int((orig_h - new_h) * r[1])
                right = left + new_w
                bottom = top + new_h
                if left == right or top == bottom:
                    continue

                # check for any valid boxes with centers within the crop area
                xyxy_bboxes = F.convert_bounding_box_format(
                    bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
                )
                cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
                if not is_within_crop_area.any():
                    continue

                # check at least 1 box with jaccard limitations
                xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
                ious = box_iou(
                    xyxy_bboxes,
                    torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
                )
                if ious.max() < min_jaccard_overlap:
                    continue

                return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if len(params) < 1:
            return inpt

        is_within_crop_area = params["is_within_crop_area"]

        if isinstance(inpt, (features.Label, features.OneHotLabel)):
            return inpt.new_like(inpt, inpt[is_within_crop_area])  # type: ignore[arg-type]

        output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

        if isinstance(output, features.BoundingBox):
            bboxes = output[is_within_crop_area]
            bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
            output = features.BoundingBox.new_like(output, bboxes)
        elif isinstance(output, features.SegmentationMask) and output.shape[-3] > 1:
            # apply is_within_crop_area if mask is one-hot encoded
            masks = output[is_within_crop_area]
            output = features.SegmentationMask.new_like(output, masks)
Comment on lines +713 to +716:

@datumbox here is support for one-hot encoded masks in the meantime; there are other solutions we could still decide on.

(A short illustrative sketch of this mask filtering follows the method below.)
        return output
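To make the one-hot mask branch above concrete: a minimal illustrative sketch, not part of this diff, with hypothetical shapes. A stack of per-object masks has shape (num_objects, H, W), so the same boolean `is_within_crop_area` index that filters boxes and labels also drops the masks of discarded objects:

```python
import torch

# Hypothetical one-hot mask stack: one (H, W) mask per object.
masks = torch.zeros(3, 4, 4, dtype=torch.uint8)
is_within_crop_area = torch.tensor([True, False, True])

# Boolean indexing along dim 0 keeps only the objects whose box centers
# fell inside the sampled crop, matching the Label/BoundingBox filtering.
kept = masks[is_within_crop_area]
print(kept.shape)  # torch.Size([2, 4, 4])
```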
    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        # TODO: Allow image to be a torch.Tensor
        if not (
            has_all(sample, features.BoundingBox)
            and has_any(sample, PIL.Image.Image, features.Image)
Comment thread on this check:

- What about plain tensors? For torchvision/prototype/transforms/_augment.py, lines 108 to 109 in 162267c: given that we have ported this from the references, there is no BC constraint to allow plain tensors and PIL images. Still, it feels unnecessarily restrictive.
- To understand your comment: you want to add support for images as Tensors and keep Image and PIL Image?
- Either that, or remove support for PIL here. We should either support all image types or only the "new one" like we do in
- Is the reason CutMix and MixUp do not support PIL a lack of implementation?
- I think we should support all types. Part of the API supporting them and part not is weird and will hinder adoption.

(A sketch of this relaxation follows the method below.)
            and has_any(sample, features.Label, features.OneHotLabel)
        ):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
            )
        return super().forward(sample)
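Sketching the relaxation discussed in the thread above (this is not code from the PR): plain tensors could be admitted via the `is_simple_tensor` helper that the diff already imports. Whether `has_any` accepts predicate callables alongside types is an assumption here, borrowed from the `_utils` helpers referenced in the thread:

```python
# Hypothetical relaxation per the review thread; assumes has_any() also
# accepts callables such as is_simple_tensor as checks (an assumption,
# not something this diff guarantees).
if not (
    has_all(sample, features.BoundingBox)
    and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
    and has_any(sample, features.Label, features.OneHotLabel)
):
    raise TypeError(f"{type(self).__name__}() requires images, bounding boxes and labels.")
```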
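For context, a minimal usage sketch against the prototype API shown in this diff; the tensor values are made up, and it assumes `RandomIoUCrop` is exported from `torchvision.prototype.transforms`:

```python
import torch
from torchvision.prototype import features, transforms

image = features.Image(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
boxes = features.BoundingBox(
    torch.tensor([[50.0, 60.0, 200.0, 220.0], [300.0, 100.0, 400.0, 250.0]]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(480, 640),
)
labels = features.Label(torch.tensor([1, 2]))

transform = transforms.RandomIoUCrop()
# Each call resamples a crop whose IoU with at least one surviving box meets
# a randomly drawn threshold from sampler_options (or returns the sample
# unchanged when the 1.0 "leave as-is" option is drawn).
new_image, new_boxes, new_labels = transform(image, boxes, labels)
```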
class ScaleJitter(Transform):
    def __init__(
        self,