diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 3e63316ddd2..bb4540d1aae 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Dict, Optional +from typing import List, Tuple, Dict, Optional, Union import torch import torchvision @@ -401,3 +401,39 @@ def forward(self, img, target=None): img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom]) return img, target + + +class RandomShortestSize(nn.Module): + def __init__( + self, + min_size: Union[List[int], Tuple[int], int], + max_size: int, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ): + super().__init__() + self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) + self.max_size = max_size + self.interpolation = interpolation + + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + _, orig_height, orig_width = F.get_dimensions(image) + + min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()] + r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width)) + + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) + + if target is not None: + target["boxes"][:, 0::2] *= new_width / orig_width + target["boxes"][:, 1::2] *= new_height / orig_height + if "masks" in target: + target["masks"] = F.resize( + target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST + ) + + return image, target