From e1a356994053be9ed27916e0de31e907bd473070 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 28 Feb 2022 11:56:40 +0100 Subject: [PATCH 1/4] add prototype AugMix transform --- test/test_prototype_transforms.py | 1 + torchvision/prototype/transforms/__init__.py | 2 +- .../prototype/transforms/_auto_augment.py | 388 ++++++++++++------ torchvision/prototype/transforms/_utils.py | 19 +- 4 files changed, 275 insertions(+), 135 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 190867523eb..52101368800 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -114,6 +114,7 @@ def test_mixup_cutmix(self, transform, input): transforms.RandAugment(), transforms.TrivialAugmentWide(), transforms.AutoAugment(), + transforms.AugMix(), ) ] ) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 73235720d58..273e9657044 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -5,7 +5,7 @@ from ._transform import Transform # usort: skip from ._augment import RandomErasing, RandomMixup, RandomCutmix -from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment +from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 7eae25a681e..96398f80d44 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,13 +1,15 @@ +import functools import math -from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar +from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F from torchvision.prototype.utils._internal import apply_recursively +from torchvision.transforms.functional import pil_to_tensor, to_pil_image -from ._utils import query_image +from ._utils import query_images K = TypeVar("K") V = TypeVar("V") @@ -26,27 +28,16 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: key = keys[int(torch.randint(len(keys), ()))] return key, dct[key] - def _apply_transform(self, sample: Any, transform_id: str, magnitude: float) -> Any: - def dispatch( - image_tensor_kernel: Callable, - image_pil_kernel: Callable, - input: Any, - *args: Any, - **kwargs: Any, - ) -> Any: - if isinstance(input, (features.BoundingBox, features.SegmentationMask)): - raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") - elif isinstance(input, features.Image): - output = image_tensor_kernel(input, *args, **kwargs) - return features.Image.new_like(input, output) - elif isinstance(input, torch.Tensor): - return image_tensor_kernel(input, *args, **kwargs) - elif isinstance(input, PIL.Image.Image): - return image_pil_kernel(input, *args, **kwargs) - else: - return input + def _query_image(self, sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: + images = list(query_images(sample)) + if len(images) > 1: + 
raise TypeError( + f"Auto augment transformations are only properly defined for a single image, but found {len(images)}." + ) + return images[0] - image = query_image(sample) + def _parse_fill(self, sample: Any) -> Optional[List[float]]: + image = self._query_image(sample) num_channels = F.get_image_num_channels(image) fill = self.fill @@ -54,106 +45,137 @@ def dispatch( fill = [float(fill)] * num_channels elif fill is not None: fill = [float(f) for f in fill] + return fill + + def _dispatch( + self, + image_tensor_kernel: Callable, + image_pil_kernel: Callable, + input: Any, + *args: Any, + **kwargs: Any, + ) -> Any: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = image_tensor_kernel(input, *args, **kwargs) + return features.Image.new_like(input, output) + elif isinstance(input, torch.Tensor): + return image_tensor_kernel(input, *args, **kwargs) + elif isinstance(input, PIL.Image.Image): + return image_pil_kernel(input, *args, **kwargs) + else: + return input + + def _apply_transform_to_item( + self, + item: Any, + transform_id: str, + magnitude: float, + interpolation: InterpolationMode, + fill: Optional[List[float]], + ) -> Any: + if transform_id == "Identity": + return item + elif transform_id == "ShearX": + return self._dispatch( + F.affine_image_tensor, + F.affine_image_pil, + item, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "ShearY": + return self._dispatch( + F.affine_image_tensor, + F.affine_image_pil, + item, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "TranslateX": + return self._dispatch( + F.affine_image_tensor, + F.affine_image_pil, + item, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "TranslateY": + return self._dispatch( + F.affine_image_tensor, + F.affine_image_pil, + item, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "Rotate": + return self._dispatch(F.rotate_image_tensor, F.rotate_image_pil, item, angle=magnitude) + elif transform_id == "Brightness": + return self._dispatch( + F.adjust_brightness_image_tensor, + F.adjust_brightness_image_pil, + item, + brightness_factor=1.0 + magnitude, + ) + elif transform_id == "Color": + return self._dispatch( + F.adjust_saturation_image_tensor, + F.adjust_saturation_image_pil, + item, + saturation_factor=1.0 + magnitude, + ) + elif transform_id == "Contrast": + return self._dispatch( + F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, item, contrast_factor=1.0 + magnitude + ) + elif transform_id == "Sharpness": + return self._dispatch( + F.adjust_sharpness_image_tensor, + F.adjust_sharpness_image_pil, + item, + sharpness_factor=1.0 + magnitude, + ) + elif transform_id == "Posterize": + return self._dispatch(F.posterize_image_tensor, F.posterize_image_pil, item, bits=int(magnitude)) + elif transform_id == "Solarize": + return self._dispatch(F.solarize_image_tensor, F.solarize_image_pil, item, threshold=magnitude) + elif transform_id == "AutoContrast": + return self._dispatch(F.autocontrast_image_tensor, 
F.autocontrast_image_pil, item) + elif transform_id == "Equalize": + return self._dispatch(F.equalize_image_tensor, F.equalize_image_pil, item) + elif transform_id == "Invert": + return self._dispatch(F.invert_image_tensor, F.invert_image_pil, item) + else: + raise ValueError(f"No transform available for {transform_id}") - interpolation = self.interpolation - - def transform(input: Any) -> Any: - if type(input) in {features.BoundingBox, features.SegmentationMask}: - raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") - elif not (type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image)): - return input - - if transform_id == "Identity": - return input - elif transform_id == "ShearX": - return dispatch( - F.affine_image_tensor, - F.affine_image_pil, - input, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[math.degrees(magnitude), 0.0], - interpolation=interpolation, - fill=fill, - ) - elif transform_id == "ShearY": - return dispatch( - F.affine_image_tensor, - F.affine_image_pil, - input, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[0.0, math.degrees(magnitude)], - interpolation=interpolation, - fill=fill, - ) - elif transform_id == "TranslateX": - return dispatch( - F.affine_image_tensor, - F.affine_image_pil, - input, - angle=0.0, - translate=[int(magnitude), 0], - scale=1.0, - shear=[0.0, 0.0], - interpolation=interpolation, - fill=fill, - ) - elif transform_id == "TranslateY": - return dispatch( - F.affine_image_tensor, - F.affine_image_pil, - input, - angle=0.0, - translate=[0, int(magnitude)], - scale=1.0, - shear=[0.0, 0.0], - interpolation=interpolation, - fill=fill, - ) - elif transform_id == "Rotate": - return dispatch(F.rotate_image_tensor, F.rotate_image_pil, input, angle=magnitude) - elif transform_id == "Brightness": - return dispatch( - F.adjust_brightness_image_tensor, - F.adjust_brightness_image_pil, - input, - brightness_factor=1.0 + magnitude, - ) - elif transform_id == "Color": - return dispatch( - F.adjust_saturation_image_tensor, - F.adjust_saturation_image_pil, - input, - saturation_factor=1.0 + magnitude, - ) - elif transform_id == "Contrast": - return dispatch( - F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, input, contrast_factor=1.0 + magnitude - ) - elif transform_id == "Sharpness": - return dispatch( - F.adjust_sharpness_image_tensor, - F.adjust_sharpness_image_pil, - input, - sharpness_factor=1.0 + magnitude, - ) - elif transform_id == "Posterize": - return dispatch(F.posterize_image_tensor, F.posterize_image_pil, input, bits=int(magnitude)) - elif transform_id == "Solarize": - return dispatch(F.solarize_image_tensor, F.solarize_image_pil, input, threshold=magnitude) - elif transform_id == "AutoContrast": - return dispatch(F.autocontrast_image_tensor, F.autocontrast_image_pil, input) - elif transform_id == "Equalize": - return dispatch(F.equalize_image_tensor, F.equalize_image_pil, input) - elif transform_id == "Invert": - return dispatch(F.invert_image_tensor, F.invert_image_pil, input) - else: - raise ValueError(f"No transform available for {transform_id}") - - return apply_recursively(transform, sample) + def _apply_transform_to_sample(self, sample: Any, transform_id: str, magnitude: float) -> Any: + return apply_recursively( + functools.partial( + self._apply_transform_to_item, + transform_id=transform_id, + magnitude=magnitude, + interpolation=self.interpolation, + fill=self._parse_fill(sample), + ), + sample, + ) class AutoAugment(_AutoAugmentBase): @@ -277,7 +299,7 @@ def 
_get_policies( def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - image = query_image(sample) + image = self._query_image(sample) image_size = F.get_image_size(image) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -296,7 +318,7 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - sample = self._apply_transform(sample, transform_id, magnitude) + sample = self._apply_transform_to_sample(sample, transform_id, magnitude) return sample @@ -332,8 +354,9 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] + self._query_image(sample) - image = query_image(sample) + image = self._query_image(sample) image_size = F.get_image_size(image) for _ in range(self.num_ops): @@ -347,7 +370,7 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - sample = self._apply_transform(sample, transform_id, magnitude) + sample = self._apply_transform_to_sample(sample, transform_id, magnitude) return sample @@ -382,7 +405,7 @@ def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - image = query_image(sample) + image = self._query_image(sample) image_size = F.get_image_size(image) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -395,4 +418,115 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - return self._apply_transform(sample, transform_id, magnitude) + return self._apply_transform_to_sample(sample, transform_id, magnitude) + + +class AugMix(_AutoAugmentBase): + _PARTIAL_AUGMENTATION_SPACE = { + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + "Posterize": ( + lambda num_bins, image_size: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + .round() + .int(), + False, + ), + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), + } + _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, Tuple[int, int]], Optional[torch.Tensor]], bool]] = { + **_PARTIAL_AUGMENTATION_SPACE, + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + } + + def __init__( + self, + severity: int = 3, + mixture_width: int = 3, + chain_depth: int = -1, + alpha: float = 1.0, + all_ops: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._PARAMETER_MAX = 10 + if not (1 <= severity <= self._PARAMETER_MAX): + raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. 
Got {severity} instead.") + self.severity = severity + self.mixture_width = mixture_width + self.chain_depth = chain_depth + self.alpha = alpha + self.all_ops = all_ops + + def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: + # Must be on a separate method so that we can overwrite it in tests. + return torch._sample_dirichlet(params) + + def _apply_augmix(self, input: Any) -> Any: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, torch.Tensor): + image = input + elif isinstance(input, PIL.Image.Image): + image = pil_to_tensor(input) + else: + return input + + augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE + + image_size = F.get_image_size(image) + + orig_dims = list(image.shape) + batch = image.view([1] * max(4 - image.ndim, 0) + orig_dims) + batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) + + # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet + # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image. + m = self._sample_dirichlet( + torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) + ) + + # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images. + combined_weights = self._sample_dirichlet( + torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) + ) * m[:, 1].view([batch_dims[0], -1]) + + mix = m[:, 0].view(batch_dims) * batch + for i in range(self.mixture_width): + aug = batch + depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) + for _ in range(depth): + transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space) + + magnitudes = magnitudes_fn(self._PARAMETER_MAX, image_size) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.severity, ()))]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + aug = self._apply_transform_to_item( + image, transform_id, magnitude, interpolation=self.interpolation, fill=self._parse_fill(image) + ) + mix.add_(combined_weights[:, i].view(batch_dims) * aug) + mix = mix.view(orig_dims).to(dtype=image.dtype) + + if isinstance(input, features.Image): + return features.Image.new_like(input, mix) + elif isinstance(input, torch.Tensor): + return mix + else: # isinstance(input, PIL.Image.Image): + return to_pil_image(mix) + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + self._query_image(sample) + return apply_recursively(self._apply_augmix, sample) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 24d794a2cb4..c359abf8d13 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Optional, Union, Iterator import PIL.Image import torch @@ -6,14 +6,19 @@ from torchvision.prototype.utils._internal import query_recursively -def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: - def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]: - if type(input) in {torch.Tensor, features.Image} or 
isinstance(input, PIL.Image.Image): - return input +def _extract_image(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]: + if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): + return input + + return None + - return None +def query_images(sample: Any) -> Iterator[Union[PIL.Image.Image, torch.Tensor, features.Image]]: + return query_recursively(_extract_image, sample) + +def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: try: - return next(query_recursively(fn, sample)) + return next(query_images(sample)) except StopIteration: raise TypeError("No image was found in the sample") From 67c30569832eeb63f3ce59dbdf28d6653943ff49 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 28 Feb 2022 14:48:06 +0100 Subject: [PATCH 2/4] cleanup --- .../prototype/transforms/_auto_augment.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index ad99b3fee2e..e08981d9479 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -37,14 +37,22 @@ def _query_image(self, sample: Any) -> Union[PIL.Image.Image, torch.Tensor, feat return images[0] def _parse_fill(self, sample: Any) -> Optional[List[float]]: + fill = self.fill + + if fill is None: + return fill + image = self._query_image(sample) - num_channels, *_ = get_image_dimensions(image) - fill = self.fill + if not isinstance(image, torch.Tensor): + return fill + if isinstance(fill, (int, float)): + num_channels, *_ = get_image_dimensions(image) fill = [float(fill)] * num_channels - elif fill is not None: + else: fill = [float(f) for f in fill] + return fill def _dispatch( @@ -354,7 +362,6 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - self._query_image(sample) image = self._query_image(sample) _, height, width = get_image_dimensions(image) @@ -482,6 +489,7 @@ def _apply_augmix(self, input: Any) -> Any: augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE _, height, width = get_image_dimensions(image) + fill = self._parse_fill(image) orig_dims = list(image.shape) batch = image.view([1] * max(4 - image.ndim, 0) + orig_dims) @@ -514,7 +522,7 @@ def _apply_augmix(self, input: Any) -> Any: magnitude = 0.0 aug = self._apply_transform_to_item( - image, transform_id, magnitude, interpolation=self.interpolation, fill=self._parse_fill(image) + image, transform_id, magnitude, interpolation=self.interpolation, fill=fill ) mix.add_(combined_weights[:, i].view(batch_dims) * aug) mix = mix.view(orig_dims).to(dtype=image.dtype) From b6375e6d5d2fa101433ca8dc578d6984d2bf19eb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 28 Feb 2022 16:44:54 +0100 Subject: [PATCH 3/4] refactor auto augment subclasses to only trnasform a single image --- .../prototype/transforms/_auto_augment.py | 198 ++++++++++-------- torchvision/prototype/transforms/_utils.py | 21 +- torchvision/prototype/utils/_internal.py | 17 +- 3 files changed, 126 insertions(+), 110 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index e08981d9479..eb385defcad 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ 
b/torchvision/prototype/transforms/_auto_augment.py @@ -1,4 +1,3 @@ -import functools import math from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union @@ -6,15 +5,27 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F -from torchvision.prototype.utils._internal import apply_recursively +from torchvision.prototype.utils._internal import query_recursively from torchvision.transforms.functional import pil_to_tensor, to_pil_image -from ._utils import query_images, get_image_dimensions +from ._utils import get_image_dimensions K = TypeVar("K") V = TypeVar("V") +def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any: + if not id: + return item + + parent = sample + for key in id[:-1]: + parent = parent[key] + + parent[id[-1]] = item + return sample + + class _AutoAugmentBase(Transform): def __init__( self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None @@ -28,34 +39,47 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: key = keys[int(torch.randint(len(keys), ()))] return key, dct[key] - def _query_image(self, sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: - images = list(query_images(sample)) + def _check_support(self, input: Any) -> None: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + + def _extract_image( + self, sample: Any + ) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]: + def fn( + id: Tuple[Any, ...], input: Any + ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]: + if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): + return id, input + + self._check_support(input) + return None + + images = list(query_recursively(fn, sample)) + if not images: + raise TypeError("Found no image in the sample.") if len(images) > 1: raise TypeError( f"Auto augment transformations are only properly defined for a single image, but found {len(images)}." 
) return images[0] - def _parse_fill(self, sample: Any) -> Optional[List[float]]: + def _parse_fill( + self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int + ) -> Optional[List[float]]: fill = self.fill - if fill is None: - return fill - - image = self._query_image(sample) - - if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image) or fill is None: return fill if isinstance(fill, (int, float)): - num_channels, *_ = get_image_dimensions(image) fill = [float(fill)] * num_channels else: fill = [float(f) for f in fill] return fill - def _dispatch( + def _dispatch_image_kernels( self, image_tensor_kernel: Callable, image_pil_kernel: Callable, @@ -63,33 +87,29 @@ def _dispatch( *args: Any, **kwargs: Any, ) -> Any: - if isinstance(input, (features.BoundingBox, features.SegmentationMask)): - raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") - elif isinstance(input, features.Image): + if isinstance(input, features.Image): output = image_tensor_kernel(input, *args, **kwargs) return features.Image.new_like(input, output) elif isinstance(input, torch.Tensor): return image_tensor_kernel(input, *args, **kwargs) - elif isinstance(input, PIL.Image.Image): + else: # isinstance(input, PIL.Image.Image): return image_pil_kernel(input, *args, **kwargs) - else: - return input - def _apply_transform_to_item( + def _apply_image_transform( self, - item: Any, + image: Any, transform_id: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]], ) -> Any: if transform_id == "Identity": - return item + return image elif transform_id == "ShearX": - return self._dispatch( + return self._dispatch_image_kernels( F.affine_image_tensor, F.affine_image_pil, - item, + image, angle=0.0, translate=[0, 0], scale=1.0, @@ -98,10 +118,10 @@ def _apply_transform_to_item( fill=fill, ) elif transform_id == "ShearY": - return self._dispatch( + return self._dispatch_image_kernels( F.affine_image_tensor, F.affine_image_pil, - item, + image, angle=0.0, translate=[0, 0], scale=1.0, @@ -110,10 +130,10 @@ def _apply_transform_to_item( fill=fill, ) elif transform_id == "TranslateX": - return self._dispatch( + return self._dispatch_image_kernels( F.affine_image_tensor, F.affine_image_pil, - item, + image, angle=0.0, translate=[int(magnitude), 0], scale=1.0, @@ -122,10 +142,10 @@ def _apply_transform_to_item( fill=fill, ) elif transform_id == "TranslateY": - return self._dispatch( + return self._dispatch_image_kernels( F.affine_image_tensor, F.affine_image_pil, - item, + image, angle=0.0, translate=[0, int(magnitude)], scale=1.0, @@ -134,57 +154,49 @@ def _apply_transform_to_item( fill=fill, ) elif transform_id == "Rotate": - return self._dispatch(F.rotate_image_tensor, F.rotate_image_pil, item, angle=magnitude) + return self._dispatch_image_kernels(F.rotate_image_tensor, F.rotate_image_pil, image, angle=magnitude) elif transform_id == "Brightness": - return self._dispatch( + return self._dispatch_image_kernels( F.adjust_brightness_image_tensor, F.adjust_brightness_image_pil, - item, + image, brightness_factor=1.0 + magnitude, ) elif transform_id == "Color": - return self._dispatch( + return self._dispatch_image_kernels( F.adjust_saturation_image_tensor, F.adjust_saturation_image_pil, - item, + image, saturation_factor=1.0 + magnitude, ) elif transform_id == "Contrast": - return self._dispatch( - F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, item, contrast_factor=1.0 + magnitude + return 
self._dispatch_image_kernels( + F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, image, contrast_factor=1.0 + magnitude ) elif transform_id == "Sharpness": - return self._dispatch( + return self._dispatch_image_kernels( F.adjust_sharpness_image_tensor, F.adjust_sharpness_image_pil, - item, + image, sharpness_factor=1.0 + magnitude, ) elif transform_id == "Posterize": - return self._dispatch(F.posterize_image_tensor, F.posterize_image_pil, item, bits=int(magnitude)) + return self._dispatch_image_kernels( + F.posterize_image_tensor, F.posterize_image_pil, image, bits=int(magnitude) + ) elif transform_id == "Solarize": - return self._dispatch(F.solarize_image_tensor, F.solarize_image_pil, item, threshold=magnitude) + return self._dispatch_image_kernels( + F.solarize_image_tensor, F.solarize_image_pil, image, threshold=magnitude + ) elif transform_id == "AutoContrast": - return self._dispatch(F.autocontrast_image_tensor, F.autocontrast_image_pil, item) + return self._dispatch_image_kernels(F.autocontrast_image_tensor, F.autocontrast_image_pil, image) elif transform_id == "Equalize": - return self._dispatch(F.equalize_image_tensor, F.equalize_image_pil, item) + return self._dispatch_image_kernels(F.equalize_image_tensor, F.equalize_image_pil, image) elif transform_id == "Invert": - return self._dispatch(F.invert_image_tensor, F.invert_image_pil, item) + return self._dispatch_image_kernels(F.invert_image_tensor, F.invert_image_pil, image) else: raise ValueError(f"No transform available for {transform_id}") - def _apply_transform_to_sample(self, sample: Any, transform_id: str, magnitude: float) -> Any: - return apply_recursively( - functools.partial( - self._apply_transform_to_item, - transform_id=transform_id, - magnitude=magnitude, - interpolation=self.interpolation, - fill=self._parse_fill(sample), - ), - sample, - ) - class AutoAugment(_AutoAugmentBase): _AUGMENTATION_SPACE = { @@ -307,8 +319,9 @@ def _get_policies( def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - image = self._query_image(sample) - _, height, width = get_image_dimensions(image) + id, image = self._extract_image(sample) + num_channels, height, width = get_image_dimensions(image) + fill = self._parse_fill(image, num_channels) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -326,9 +339,11 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - sample = self._apply_transform_to_sample(sample, transform_id, magnitude) + image = self._apply_image_transform( + image, transform_id, magnitude, interpolation=self.interpolation, fill=fill + ) - return sample + return _put_into_sample(sample, id, image) class RandAugment(_AutoAugmentBase): @@ -363,8 +378,9 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - image = self._query_image(sample) - _, height, width = get_image_dimensions(image) + id, image = self._extract_image(sample) + num_channels, height, width = get_image_dimensions(image) + fill = self._parse_fill(image, num_channels) for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -377,9 +393,11 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - sample = self._apply_transform_to_sample(sample, transform_id, magnitude) + image = self._apply_image_transform( + image, transform_id, magnitude, interpolation=self.interpolation, fill=fill + ) - 
return sample + return _put_into_sample(sample, id, image) class TrivialAugmentWide(_AutoAugmentBase): @@ -412,8 +430,9 @@ def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - image = self._query_image(sample) - _, height, width = get_image_dimensions(image) + id, image = self._extract_image(sample) + num_channels, height, width = get_image_dimensions(image) + fill = self._parse_fill(image, num_channels) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -425,7 +444,11 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - return self._apply_transform_to_sample(sample, transform_id, magnitude) + return _put_into_sample( + sample, + id, + self._apply_image_transform(sample, transform_id, magnitude, interpolation=self.interpolation, fill=fill), + ) class AugMix(_AutoAugmentBase): @@ -476,20 +499,18 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: # Must be on a separate method so that we can overwrite it in tests. return torch._sample_dirichlet(params) - def _apply_augmix(self, input: Any) -> Any: - if isinstance(input, (features.BoundingBox, features.SegmentationMask)): - raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") - elif isinstance(input, torch.Tensor): - image = input - elif isinstance(input, PIL.Image.Image): - image = pil_to_tensor(input) - else: - return input + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + id, orig_image = self._extract_image(sample) + num_channels, height, width = get_image_dimensions(orig_image) + fill = self._parse_fill(orig_image, num_channels) - augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE + if isinstance(orig_image, torch.Tensor): + image = orig_image + else: # isinstance(input, PIL.Image.Image): + image = pil_to_tensor(orig_image) - _, height, width = get_image_dimensions(image) - fill = self._parse_fill(image) + augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE orig_dims = list(image.shape) batch = image.view([1] * max(4 - image.ndim, 0) + orig_dims) @@ -521,20 +542,15 @@ def _apply_augmix(self, input: Any) -> Any: else: magnitude = 0.0 - aug = self._apply_transform_to_item( + aug = self._apply_image_transform( image, transform_id, magnitude, interpolation=self.interpolation, fill=fill ) mix.add_(combined_weights[:, i].view(batch_dims) * aug) mix = mix.view(orig_dims).to(dtype=image.dtype) - if isinstance(input, features.Image): - return features.Image.new_like(input, mix) - elif isinstance(input, torch.Tensor): - return mix - else: # isinstance(input, PIL.Image.Image): - return to_pil_image(mix) + if isinstance(orig_image, features.Image): + mix = features.Image.new_like(orig_image, mix) + elif isinstance(orig_image, PIL.Image.Image): + mix = to_pil_image(mix) - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - self._query_image(sample) - return apply_recursively(self._apply_augmix, sample) + return _put_into_sample(sample, id, mix) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index bc66a46e09c..f5107908e41 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple, Union, Iterator +from 
typing import Any, Optional, Tuple, Union import PIL.Image import torch @@ -8,20 +8,17 @@ from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil -def _extract_image(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]: - if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): - return input - - return None - - -def query_images(sample: Any) -> Iterator[Union[PIL.Image.Image, torch.Tensor, features.Image]]: - return query_recursively(_extract_image, sample) +def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: + def fn( + id: Tuple[Any, ...], input: Any + ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]: + if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): + return id, input + return None -def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: try: - return next(query_images(sample)) + return next(query_recursively(fn, sample))[1] except StopIteration: raise TypeError("No image was found in the sample") diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 2e38471ea65..366a19f2bbc 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -312,15 +312,18 @@ def apply_recursively(fn: Callable, obj: Any) -> Any: return fn(obj) -def query_recursively(fn: Callable[[Any], Optional[D]], obj: Any) -> Iterator[D]: +def query_recursively( + fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = () +) -> Iterator[D]: # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: # "a" == "a"[0][0]... 
- if (isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)) or isinstance( - obj, collections.abc.Mapping - ): - for item in obj.values() if isinstance(obj, collections.abc.Mapping) else obj: - yield from query_recursively(fn, item) + if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): + for idx, item in enumerate(obj): + yield from query_recursively(fn, item, id=(*id, idx)) + elif isinstance(obj, collections.abc.Mapping): + for key, item in obj.items(): + yield from query_recursively(fn, item, id=(*id, key)) else: - result = fn(obj) + result = fn(id, obj) if result is not None: yield result From 3a455d8c45155212e95ffd790fb1c51f7cf2cc83 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 28 Feb 2022 17:21:03 +0100 Subject: [PATCH 4/4] address review comments --- torchvision/prototype/transforms/_auto_augment.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index eb385defcad..892162fa296 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -39,7 +39,7 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: key = keys[int(torch.randint(len(keys), ()))] return key, dct[key] - def _check_support(self, input: Any) -> None: + def _check_unsupported(self, input: Any) -> None: if isinstance(input, (features.BoundingBox, features.SegmentationMask)): raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") @@ -52,7 +52,7 @@ def fn( if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): return id, input - self._check_support(input) + self._check_unsupported(input) return None images = list(query_recursively(fn, sample)) @@ -444,11 +444,8 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - return _put_into_sample( - sample, - id, - self._apply_image_transform(sample, transform_id, magnitude, interpolation=self.interpolation, fill=fill), - ) + image = self._apply_image_transform(image, transform_id, magnitude, interpolation=self.interpolation, fill=fill) + return _put_into_sample(sample, id, image) class AugMix(_AutoAugmentBase): @@ -543,7 +540,7 @@ def forward(self, *inputs: Any) -> Any: magnitude = 0.0 aug = self._apply_image_transform( - image, transform_id, magnitude, interpolation=self.interpolation, fill=fill + aug, transform_id, magnitude, interpolation=self.interpolation, fill=fill ) mix.add_(combined_weights[:, i].view(batch_dims) * aug) mix = mix.view(orig_dims).to(dtype=image.dtype)