diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 190867523eb..52101368800 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -114,6 +114,7 @@ def test_mixup_cutmix(self, transform, input): transforms.RandAugment(), transforms.TrivialAugmentWide(), transforms.AutoAugment(), + transforms.AugMix(), ) ] ) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 73d45097a93..98ad7ae0d74 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -5,7 +5,7 @@ from ._transform import Transform # usort: skip from ._augment import RandomErasing, RandomMixup, RandomCutmix -from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment +from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 78cbf958ccd..892162fa296 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,18 +1,31 @@ import math -from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar +from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Union import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F -from torchvision.prototype.utils._internal import apply_recursively +from torchvision.prototype.utils._internal import query_recursively +from torchvision.transforms.functional import pil_to_tensor, to_pil_image -from ._utils import query_image, get_image_dimensions +from ._utils import get_image_dimensions K = TypeVar("K") V = TypeVar("V") +def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any: + if not id: + return item + + parent = sample + for key in id[:-1]: + parent = parent[key] + + parent[id[-1]] = item + return sample + + class _AutoAugmentBase(Transform): def __init__( self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None @@ -26,134 +39,163 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: key = keys[int(torch.randint(len(keys), ()))] return key, dct[key] - def _apply_transform(self, sample: Any, transform_id: str, magnitude: float) -> Any: - def dispatch( - image_tensor_kernel: Callable, - image_pil_kernel: Callable, - input: Any, - *args: Any, - **kwargs: Any, - ) -> Any: - if isinstance(input, (features.BoundingBox, features.SegmentationMask)): - raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") - elif isinstance(input, features.Image): - output = image_tensor_kernel(input, *args, **kwargs) - return features.Image.new_like(input, output) - elif isinstance(input, torch.Tensor): - return image_tensor_kernel(input, *args, **kwargs) - elif isinstance(input, PIL.Image.Image): - return image_pil_kernel(input, *args, **kwargs) - else: - return input + def _check_unsupported(self, input: Any) -> None: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + + def _extract_image( + self, sample: Any + ) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]: + def fn( + id: Tuple[Any, ...], input: Any + ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]: + if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): + return id, input + + self._check_unsupported(input) + return None + + images = list(query_recursively(fn, sample)) + if not images: + raise TypeError("Found no image in the sample.") + if len(images) > 1: + raise TypeError( + f"Auto augment transformations are only properly defined for a single image, but found {len(images)}." + ) + return images[0] + + def _parse_fill( + self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int + ) -> Optional[List[float]]: + fill = self.fill - image = query_image(sample) - num_channels, *_ = get_image_dimensions(image) + if isinstance(image, PIL.Image.Image) or fill is None: + return fill - fill = self.fill if isinstance(fill, (int, float)): fill = [float(fill)] * num_channels - elif fill is not None: + else: fill = [float(f) for f in fill] - interpolation = self.interpolation - - def transform(input: Any) -> Any: - if type(input) in {features.BoundingBox, features.SegmentationMask}: - raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") - elif not (type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image)): - return input - - if transform_id == "Identity": - return input - elif transform_id == "ShearX": - return dispatch( - F.affine_image_tensor, - F.affine_image_pil, - input, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[math.degrees(magnitude), 0.0], - interpolation=interpolation, - fill=fill, - ) - elif transform_id == "ShearY": - return dispatch( - F.affine_image_tensor, - F.affine_image_pil, - input, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[0.0, math.degrees(magnitude)], - interpolation=interpolation, - fill=fill, - ) - elif transform_id == "TranslateX": - return dispatch( - F.affine_image_tensor, - F.affine_image_pil, - input, - angle=0.0, - translate=[int(magnitude), 0], - scale=1.0, - shear=[0.0, 0.0], - interpolation=interpolation, - fill=fill, - ) - elif transform_id == "TranslateY": - return dispatch( - F.affine_image_tensor, - F.affine_image_pil, - input, - angle=0.0, - translate=[0, int(magnitude)], - scale=1.0, - shear=[0.0, 0.0], - interpolation=interpolation, - fill=fill, - ) - elif transform_id == "Rotate": - return dispatch(F.rotate_image_tensor, F.rotate_image_pil, input, angle=magnitude) - elif transform_id == "Brightness": - return dispatch( - F.adjust_brightness_image_tensor, - F.adjust_brightness_image_pil, - input, - brightness_factor=1.0 + magnitude, - ) - elif transform_id == "Color": - return dispatch( - F.adjust_saturation_image_tensor, - F.adjust_saturation_image_pil, - input, - saturation_factor=1.0 + magnitude, - ) - elif transform_id == "Contrast": - return dispatch( - F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, input, contrast_factor=1.0 + magnitude - ) - elif transform_id == "Sharpness": - return dispatch( - F.adjust_sharpness_image_tensor, - F.adjust_sharpness_image_pil, - input, - sharpness_factor=1.0 + magnitude, - ) - elif transform_id == "Posterize": - return dispatch(F.posterize_image_tensor, F.posterize_image_pil, input, bits=int(magnitude)) - elif transform_id == "Solarize": - return dispatch(F.solarize_image_tensor, F.solarize_image_pil, input, threshold=magnitude) - elif transform_id == "AutoContrast": - return dispatch(F.autocontrast_image_tensor, F.autocontrast_image_pil, input) - elif transform_id == "Equalize": - return dispatch(F.equalize_image_tensor, F.equalize_image_pil, input) - elif transform_id == "Invert": - return dispatch(F.invert_image_tensor, F.invert_image_pil, input) - else: - raise ValueError(f"No transform available for {transform_id}") - - return apply_recursively(transform, sample) + return fill + + def _dispatch_image_kernels( + self, + image_tensor_kernel: Callable, + image_pil_kernel: Callable, + input: Any, + *args: Any, + **kwargs: Any, + ) -> Any: + if isinstance(input, features.Image): + output = image_tensor_kernel(input, *args, **kwargs) + return features.Image.new_like(input, output) + elif isinstance(input, torch.Tensor): + return image_tensor_kernel(input, *args, **kwargs) + else: # isinstance(input, PIL.Image.Image): + return image_pil_kernel(input, *args, **kwargs) + + def _apply_image_transform( + self, + image: Any, + transform_id: str, + magnitude: float, + interpolation: InterpolationMode, + fill: Optional[List[float]], + ) -> Any: + if transform_id == "Identity": + return image + elif transform_id == "ShearX": + return self._dispatch_image_kernels( + F.affine_image_tensor, + F.affine_image_pil, + image, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "ShearY": + return self._dispatch_image_kernels( + F.affine_image_tensor, + F.affine_image_pil, + image, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "TranslateX": + return self._dispatch_image_kernels( + F.affine_image_tensor, + F.affine_image_pil, + image, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "TranslateY": + return self._dispatch_image_kernels( + F.affine_image_tensor, + F.affine_image_pil, + image, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "Rotate": + return self._dispatch_image_kernels(F.rotate_image_tensor, F.rotate_image_pil, image, angle=magnitude) + elif transform_id == "Brightness": + return self._dispatch_image_kernels( + F.adjust_brightness_image_tensor, + F.adjust_brightness_image_pil, + image, + brightness_factor=1.0 + magnitude, + ) + elif transform_id == "Color": + return self._dispatch_image_kernels( + F.adjust_saturation_image_tensor, + F.adjust_saturation_image_pil, + image, + saturation_factor=1.0 + magnitude, + ) + elif transform_id == "Contrast": + return self._dispatch_image_kernels( + F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, image, contrast_factor=1.0 + magnitude + ) + elif transform_id == "Sharpness": + return self._dispatch_image_kernels( + F.adjust_sharpness_image_tensor, + F.adjust_sharpness_image_pil, + image, + sharpness_factor=1.0 + magnitude, + ) + elif transform_id == "Posterize": + return self._dispatch_image_kernels( + F.posterize_image_tensor, F.posterize_image_pil, image, bits=int(magnitude) + ) + elif transform_id == "Solarize": + return self._dispatch_image_kernels( + F.solarize_image_tensor, F.solarize_image_pil, image, threshold=magnitude + ) + elif transform_id == "AutoContrast": + return self._dispatch_image_kernels(F.autocontrast_image_tensor, F.autocontrast_image_pil, image) + elif transform_id == "Equalize": + return self._dispatch_image_kernels(F.equalize_image_tensor, F.equalize_image_pil, image) + elif transform_id == "Invert": + return self._dispatch_image_kernels(F.invert_image_tensor, F.invert_image_pil, image) + else: + raise ValueError(f"No transform available for {transform_id}") class AutoAugment(_AutoAugmentBase): @@ -277,8 +319,9 @@ def _get_policies( def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - image = query_image(sample) - _, height, width = get_image_dimensions(image) + id, image = self._extract_image(sample) + num_channels, height, width = get_image_dimensions(image) + fill = self._parse_fill(image, num_channels) policy = self._policies[int(torch.randint(len(self._policies), ()))] @@ -296,9 +339,11 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - sample = self._apply_transform(sample, transform_id, magnitude) + image = self._apply_image_transform( + image, transform_id, magnitude, interpolation=self.interpolation, fill=fill + ) - return sample + return _put_into_sample(sample, id, image) class RandAugment(_AutoAugmentBase): @@ -333,8 +378,9 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - image = query_image(sample) - _, height, width = get_image_dimensions(image) + id, image = self._extract_image(sample) + num_channels, height, width = get_image_dimensions(image) + fill = self._parse_fill(image, num_channels) for _ in range(self.num_ops): transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -347,9 +393,11 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - sample = self._apply_transform(sample, transform_id, magnitude) + image = self._apply_image_transform( + image, transform_id, magnitude, interpolation=self.interpolation, fill=fill + ) - return sample + return _put_into_sample(sample, id, image) class TrivialAugmentWide(_AutoAugmentBase): @@ -382,8 +430,9 @@ def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - image = query_image(sample) - _, height, width = get_image_dimensions(image) + id, image = self._extract_image(sample) + num_channels, height, width = get_image_dimensions(image) + fill = self._parse_fill(image, num_channels) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) @@ -395,4 +444,110 @@ def forward(self, *inputs: Any) -> Any: else: magnitude = 0.0 - return self._apply_transform(sample, transform_id, magnitude) + image = self._apply_image_transform(image, transform_id, magnitude, interpolation=self.interpolation, fill=fill) + return _put_into_sample(sample, id, image) + + +class AugMix(_AutoAugmentBase): + _PARTIAL_AUGMENTATION_SPACE = { + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + "Posterize": ( + lambda num_bins, image_size: cast(torch.Tensor, 4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + .round() + .int(), + False, + ), + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), + } + _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, Tuple[int, int]], Optional[torch.Tensor]], bool]] = { + **_PARTIAL_AUGMENTATION_SPACE, + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + } + + def __init__( + self, + severity: int = 3, + mixture_width: int = 3, + chain_depth: int = -1, + alpha: float = 1.0, + all_ops: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._PARAMETER_MAX = 10 + if not (1 <= severity <= self._PARAMETER_MAX): + raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.") + self.severity = severity + self.mixture_width = mixture_width + self.chain_depth = chain_depth + self.alpha = alpha + self.all_ops = all_ops + + def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: + # Must be on a separate method so that we can overwrite it in tests. + return torch._sample_dirichlet(params) + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + id, orig_image = self._extract_image(sample) + num_channels, height, width = get_image_dimensions(orig_image) + fill = self._parse_fill(orig_image, num_channels) + + if isinstance(orig_image, torch.Tensor): + image = orig_image + else: # isinstance(input, PIL.Image.Image): + image = pil_to_tensor(orig_image) + + augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE + + orig_dims = list(image.shape) + batch = image.view([1] * max(4 - image.ndim, 0) + orig_dims) + batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) + + # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet + # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image. + m = self._sample_dirichlet( + torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) + ) + + # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images. + combined_weights = self._sample_dirichlet( + torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) + ) * m[:, 1].view([batch_dims[0], -1]) + + mix = m[:, 0].view(batch_dims) * batch + for i in range(self.mixture_width): + aug = batch + depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) + for _ in range(depth): + transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space) + + magnitudes = magnitudes_fn(self._PARAMETER_MAX, (height, width)) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.severity, ()))]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + aug = self._apply_image_transform( + aug, transform_id, magnitude, interpolation=self.interpolation, fill=fill + ) + mix.add_(combined_weights[:, i].view(batch_dims) * aug) + mix = mix.view(orig_dims).to(dtype=image.dtype) + + if isinstance(orig_image, features.Image): + mix = features.Image.new_like(orig_image, mix) + elif isinstance(orig_image, PIL.Image.Image): + mix = to_pil_image(mix) + + return _put_into_sample(sample, id, mix) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index d8677d451c8..f5107908e41 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -9,14 +9,16 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]: - def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]: + def fn( + id: Tuple[Any, ...], input: Any + ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]: if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): - return input + return id, input return None try: - return next(query_recursively(fn, sample)) + return next(query_recursively(fn, sample))[1] except StopIteration: raise TypeError("No image was found in the sample") diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 2e38471ea65..366a19f2bbc 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -312,15 +312,18 @@ def apply_recursively(fn: Callable, obj: Any) -> Any: return fn(obj) -def query_recursively(fn: Callable[[Any], Optional[D]], obj: Any) -> Iterator[D]: +def query_recursively( + fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = () +) -> Iterator[D]: # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: # "a" == "a"[0][0]... - if (isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)) or isinstance( - obj, collections.abc.Mapping - ): - for item in obj.values() if isinstance(obj, collections.abc.Mapping) else obj: - yield from query_recursively(fn, item) + if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): + for idx, item in enumerate(obj): + yield from query_recursively(fn, item, id=(*id, idx)) + elif isinstance(obj, collections.abc.Mapping): + for key, item in obj.items(): + yield from query_recursively(fn, item, id=(*id, key)) else: - result = fn(obj) + result = fn(id, obj) if result is not None: yield result