diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index a4e845d75c3..5edd18890a8 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -4,7 +4,7 @@ from ._augment import RandomErasing, RandomMixup, RandomCutmix from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix -from ._color import ColorJitter, RandomPhotometricDistort +from ._color import ColorJitter, RandomPhotometricDistort, RandomEqualize from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import ( Resize, diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index cf3ceaf7ede..960020baff8 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -8,6 +8,7 @@ from torchvision.prototype.transforms import Transform, functional as F from torchvision.transforms import functional as _F +from ._transform import _RandomApplyTransform from ._utils import is_simple_tensor, get_image_dimensions, query_image T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image) @@ -188,3 +189,19 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if params["channel_shuffle"]: input = self._channel_shuffle(input) return input + + +class RandomEqualize(_RandomApplyTransform): + def __init__(self, p: float = 0.5): + super().__init__(p=p) + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, features.Image): + output = F.equalize_image_tensor(input) + return features.Image.new_like(input, output) + elif is_simple_tensor(input): + return F.equalize_image_tensor(input) + elif isinstance(input, PIL.Image.Image): + return F.equalize_image_pil(input) + else: + return input