Commit f5a4b2d

cherry-pick missing functions from pytorch#6401
1 parent 105bccd

4 files changed: +21 −2 lines

torchvision/prototype/transforms/_geometry.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 from typing_extensions import Literal

 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_bboxes, query_image
+from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_bounding_box, query_image


 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -699,7 +699,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         left = int(offset_width * r)

         if needs_crop:
-            bounding_boxes = query_bboxes(sample)
+            bounding_boxes = query_bounding_box(sample)
             bounding_boxes = F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
             bounding_boxes = features.BoundingBox.new_like(
                 bounding_boxes,

torchvision/prototype/transforms/_utils.py

Lines changed: 9 additions & 0 deletions
@@ -17,6 +17,15 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
     raise TypeError("No image was found in the sample")


+def query_bounding_box(sample: Any) -> features.BoundingBox:
+    flat_sample, _ = tree_flatten(sample)
+    for i in flat_sample:
+        if isinstance(i, features.BoundingBox):
+            return i
+
+    raise TypeError("No bounding box was found in the sample")
+
+
 def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
     if isinstance(image, features.Image):
         channels = image.num_channels
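Like query_image above it, query_bounding_box flattens an arbitrarily nested sample and returns the first features.BoundingBox it encounters. A minimal usage sketch, assuming the prototype BoundingBox constructor keywords shown here; the sample layout is made up for illustration:

# Illustrative only: the sample layout and BoundingBox constructor kwargs are
# assumptions based on the prototype API at the time of this commit.
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import query_bounding_box

sample = {
    "image": torch.rand(3, 100, 100),
    "target": {
        "boxes": features.BoundingBox(
            [[10.0, 10.0, 50.0, 50.0]],
            format=features.BoundingBoxFormat.XYXY,
            image_size=(100, 100),
        ),
    },
}

# tree_flatten walks the nested dict, so the box is found at any nesting depth;
# if no BoundingBox is present, a TypeError is raised instead.
boxes = query_bounding_box(sample)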

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
 from ._meta import (
+    clamp_bounding_box,
     convert_bounding_box_format,
     convert_color_space_image_tensor,
     convert_color_space_image_pil,

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 9 additions & 0 deletions
@@ -61,6 +61,15 @@ def convert_bounding_box_format(
     return bounding_box


+def clamp_bounding_box(
+    bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
+) -> torch.Tensor:
+    xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY)
+    xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
+    xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
+    return convert_bounding_box_format(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)
+
+
 def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return image[..., :-1, :, :], image[..., -1:, :, :]
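clamp_bounding_box round-trips through XYXY so the clamp works for any input format: x coordinates are clipped to [0, width] and y coordinates to [0, height], with image_size given as (height, width). A usage sketch; the out-of-bounds box values are invented for the example:

# Illustrative only: the box values below are made up for this example.
import torch
from torchvision.prototype.features import BoundingBoxFormat
from torchvision.prototype.transforms.functional import clamp_bounding_box

# One XYXY box spilling past a 100 x 200 (height x width) image on every side.
boxes = torch.tensor([[-5.0, -10.0, 250.0, 120.0]])

clamped = clamp_bounding_box(boxes, format=BoundingBoxFormat.XYXY, image_size=(100, 200))
# x coordinates clamp to [0, 200] (width), y coordinates to [0, 100] (height):
# tensor([[  0.,   0., 200., 100.]])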
