Skip to content

Commit d677be7

Browse files
prabhat00155 authored and vfdev-5 committed
[fbsync] Adding AugMix implementation (#5411)
Summary: * Adding basic augmix implementation. * Finish the implementation. * Add tests and documentation. * Fix tests. * Simplify code. * Speed optimizations. * Per image weights instead of per batch. * Fix tests. * Update torchvision/transforms/autoaugment.py * Changing the default severity value to get by default the same strength as RandAugment. Reviewed By: jdsgomes Differential Revision: D34475319 fbshipit-source-id: 4637ad23deace03cf1f96b5c19a310c360f179d5 Co-authored-by: vfdev <[email protected]> Co-authored-by: vfdev <[email protected]>
1 parent cfc76bf commit d677be7

File tree

6 files changed

+214
-2
lines changed

6 files changed

+214
-2
lines changed

docs/source/transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ The new transform can be used standalone or mixed-and-matched with existing tran
198198
AutoAugment
199199
RandAugment
200200
TrivialAugmentWide
201+
AugMix
201202

202203
.. _functional_transforms:
203204

gallery/plot_transforms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
263263
imgs = [augmenter(orig_img) for _ in range(4)]
264264
plot(imgs)
265265

266+
####################################
267+
# AugMix
268+
# ~~~~~~
269+
# The :class:`~torchvision.transforms.AugMix` transform automatically augments the data.
270+
augmenter = T.AugMix()
271+
imgs = [augmenter(orig_img) for _ in range(4)]
272+
plot(imgs)
273+
266274
####################################
267275
# Randomly-applied transforms
268276
# ---------------------------

references/classification/presets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def __init__(
2222
trans.append(autoaugment.RandAugment(interpolation=interpolation))
2323
elif auto_augment_policy == "ta_wide":
2424
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
25+
elif auto_augment_policy == "augmix":
26+
trans.append(autoaugment.AugMix(interpolation=interpolation))
2527
else:
2628
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
2729
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))

test/test_transforms.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,25 @@ def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale):
16011601
transform.__repr__()
16021602

16031603

1604+
@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("severity", [1, 10])
@pytest.mark.parametrize("mixture_width", [1, 2])
@pytest.mark.parametrize("chain_depth", [-1, 2])
@pytest.mark.parametrize("all_ops", [True, False])
@pytest.mark.parametrize("grayscale", [True, False])
def test_augmix(fill, severity, mixture_width, chain_depth, all_ops, grayscale):
    """Smoke-test ``transforms.AugMix`` on a PIL image for every combination of the parametrized options.

    The transform is applied 100 times in a row, each time on the previous output, and ``__repr__`` is
    exercised at the end. The test only checks that nothing raises.
    """
    # Local import so this function does not depend on the module's import block.
    import torch

    random.seed(42)
    # AugMix samples its ops, magnitudes and mixing weights through torch (torch.randint and
    # torch._sample_dirichlet), so Python's ``random`` seed alone does not make the run reproducible.
    torch.manual_seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.AugMix(
        fill=fill, severity=severity, mixture_width=mixture_width, chain_depth=chain_depth, all_ops=all_ops
    )
    for _ in range(100):
        img = transform(img)
    transform.__repr__()
1621+
1622+
16041623
def test_random_crop():
16051624
height = random.randint(10, 32) * 2
16061625
width = random.randint(10, 32) * 2

test/test_transforms_tensor.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,38 @@ def test_trivialaugmentwide(device, fill):
720720
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
721721

722722

723-
@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
723+
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "fill",
    [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1],
)
def test_augmix(device, fill):
    """Compare eager and scripted AugMix on a single uint8 image tensor and on a batch of them."""

    class DeterministicAugMix(T.AugMix):
        def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
            # Swap the random Dirichlet draw for a deterministic softmax, so that the order of
            # rand calls cannot affect the outcome of the comparison.
            return torch.softmax(params, dim=-1)

    single_image = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    image_batch = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    eager = DeterministicAugMix(fill=fill)
    scripted = torch.jit.script(eager)
    for _ in range(25):
        _test_transform_vs_scripted(eager, scripted, single_image)
        _test_transform_vs_scripted_on_batch(eager, scripted, image_batch)
752+
753+
754+
@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide, T.AugMix])
724755
def test_autoaugment_save(augmentation, tmpdir):
725756
transform = augmentation()
726757
s_transform = torch.jit.script(transform)

torchvision/transforms/autoaugment.py

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from . import functional as F, InterpolationMode
99

10-
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
10+
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
1111

1212

1313
def _apply_op(
@@ -458,3 +458,154 @@ def __repr__(self) -> str:
458458
f")"
459459
)
460460
return s
461+
462+
463+
class AugMix(torch.nn.Module):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int): The severity of base augmentation operators. Must be in ``[1, 10]``. Default is ``3``.
        mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. A non-positive value denotes stochastic depth
            sampled from the interval [1, 3]. Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.

    Raises:
        ValueError: If ``severity`` is outside ``[1, 10]``.
    """

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # Number of magnitude bins per op; also the upper bound for ``severity``.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Build the table of candidate ops.

        Args:
            num_bins (int): Number of magnitude bins generated for each op.
            image_size (List[int]): Image size; element 0 scales ``TranslateX`` and element 1 scales ``TranslateY``.

        Returns:
            Dict mapping the op name to ``(magnitudes, signed)``. ``magnitudes`` is a 0-dim tensor for ops
            that take no magnitude; ``signed`` tells whether the magnitude may be randomly negated.
        """
        s = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            # Integer values going from 4 (first bin) down to 0 (last bin).
            "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        if self.all_ops:
            s.update(
                {
                    "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Color": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
                }
            )
        return s

    @torch.jit.unused
    def _pil_to_tensor(self, img) -> Tensor:
        # Excluded from scripting (torch.jit.unused); used only for non-Tensor inputs.
        return F.pil_to_tensor(img)

    @torch.jit.unused
    def _tensor_to_pil(self, img: Tensor):
        # Excluded from scripting (torch.jit.unused); used only for non-Tensor inputs.
        return F.to_pil_image(img)

    def _sample_dirichlet(self, params: Tensor) -> Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, orig_img: Tensor) -> Tensor:
        """
        Args:
            orig_img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(orig_img, Tensor):
            img = orig_img
            # Normalize ``fill`` to a list of floats: one entry per channel for a scalar fill.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]
        else:
            img = self._pil_to_tensor(orig_img)

        op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img))
        # The set of ops does not change while sampling, so materialize the names once.
        op_names = list(op_meta.keys())

        # View the input with at least 4 dims by prepending singleton dims; dim 0 is treated as the batch.
        orig_dims = list(img.shape)
        batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
        # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].view([batch_dims[0], -1])

        mix = m[:, 0].view(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # A non-positive chain_depth means: draw the depth uniformly from {1, 2, 3} for every chain.
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                op_index = int(torch.randint(len(op_names), (1,)).item())
                op_name = op_names[op_index]
                magnitudes, signed = op_meta[op_name]
                # Only the first ``severity`` bins are eligible; ops with a 0-dim magnitude tensor get 0.0.
                magnitude = (
                    float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
                    if magnitudes.ndim > 0
                    else 0.0
                )
                if signed and torch.randint(2, (1,)):
                    magnitude *= -1.0
                aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        # Restore the caller's shape and cast back to the input dtype.
        mix = mix.view(orig_dims).to(dtype=img.dtype)

        if not isinstance(orig_img, Tensor):
            return self._tensor_to_pil(mix)
        return mix

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"severity={self.severity}"
            f", mixture_width={self.mixture_width}"
            f", chain_depth={self.chain_depth}"
            f", alpha={self.alpha}"
            f", all_ops={self.all_ops}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill}"
            f")"
        )
        return s

0 commit comments

Comments
 (0)