Skip to content

Commit f49edd3

Browse files
vfdev-5 and datumbox authored
[proto] Fixed fill type in AA (#6621)
* [proto] Fixed fill type in AA
* Fixed missed typehints
* Set fill as None by default
* Another fix

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 0fcfaa1 commit f49edd3

File tree

4 files changed

+77
-67
lines changed

4 files changed

+77
-67
lines changed

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 13 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import math
2-
import numbers
3-
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
2+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union
43

54
import PIL.Image
65
import torch
@@ -10,7 +9,7 @@
109
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
1110
from torchvision.prototype.transforms.functional._meta import get_chw
1211

13-
from ._utils import _isinstance
12+
from ._utils import _isinstance, _setup_fill_arg, FillType
1413

1514
K = TypeVar("K")
1615
V = TypeVar("V")
@@ -21,14 +20,11 @@ def __init__(
2120
self,
2221
*,
2322
interpolation: InterpolationMode = InterpolationMode.NEAREST,
24-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
23+
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
2524
) -> None:
2625
super().__init__()
2726
self.interpolation = interpolation
28-
29-
if not isinstance(fill, (numbers.Number, tuple, list)):
30-
raise TypeError("Got inappropriate fill arg")
31-
self.fill = fill
27+
self.fill = _setup_fill_arg(fill)
3228

3329
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
3430
keys = tuple(dct.keys())
@@ -63,19 +59,14 @@ def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
6359

6460
def _apply_image_transform(
6561
self,
66-
image: Any,
62+
image: Union[torch.Tensor, PIL.Image.Image, features.Image],
6763
transform_id: str,
6864
magnitude: float,
6965
interpolation: InterpolationMode,
70-
fill: Union[int, float, Sequence[int], Sequence[float]],
66+
fill: Union[Dict[Type, FillType], Dict[Type, None]],
7167
) -> Any:
72-
73-
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
74-
# So, we have to put fill as None if fill == 0
75-
# This is due to BC with stable API which has fill = None by default
76-
fill_ = F._geometry._convert_fill_arg(fill)
77-
if isinstance(fill, int) and fill == 0:
78-
fill_ = None
68+
fill_ = fill[type(image)]
69+
fill_ = F._geometry._convert_fill_arg(fill_)
7970

8071
if transform_id == "Identity":
8172
return image
@@ -186,7 +177,7 @@ def __init__(
186177
self,
187178
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
188179
interpolation: InterpolationMode = InterpolationMode.NEAREST,
189-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
180+
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
190181
) -> None:
191182
super().__init__(interpolation=interpolation, fill=fill)
192183
self.policy = policy
@@ -286,7 +277,7 @@ def forward(self, *inputs: Any) -> Any:
286277
sample = inputs if len(inputs) > 1 else inputs[0]
287278

288279
id, image = self._extract_image(sample)
289-
num_channels, height, width = get_chw(image)
280+
_, height, width = get_chw(image)
290281

291282
policy = self._policies[int(torch.randint(len(self._policies), ()))]
292283

@@ -346,7 +337,7 @@ def __init__(
346337
magnitude: int = 9,
347338
num_magnitude_bins: int = 31,
348339
interpolation: InterpolationMode = InterpolationMode.NEAREST,
349-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
340+
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
350341
) -> None:
351342
super().__init__(interpolation=interpolation, fill=fill)
352343
self.num_ops = num_ops
@@ -402,7 +393,7 @@ def __init__(
402393
self,
403394
num_magnitude_bins: int = 31,
404395
interpolation: InterpolationMode = InterpolationMode.NEAREST,
405-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
396+
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
406397
):
407398
super().__init__(interpolation=interpolation, fill=fill)
408399
self.num_magnitude_bins = num_magnitude_bins
@@ -462,7 +453,7 @@ def __init__(
462453
alpha: float = 1.0,
463454
all_ops: bool = True,
464455
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
465-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
456+
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
466457
) -> None:
467458
super().__init__(interpolation=interpolation, fill=fill)
468459
self._PARAMETER_MAX = 10

torchvision/prototype/transforms/_deprecated.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,7 @@
1111
from typing_extensions import Literal
1212

1313
from ._transform import _RandomApplyTransform
14-
from ._utils import query_chw
15-
16-
17-
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
14+
from ._utils import DType, query_chw
1815

1916

2017
class ToTensor(Transform):

torchvision/prototype/transforms/_geometry.py

Lines changed: 14 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import math
22
import numbers
33
import warnings
4-
from collections import defaultdict
54
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
65

76
import PIL.Image
@@ -14,11 +13,20 @@
1413
from typing_extensions import Literal
1514

1615
from ._transform import _RandomApplyTransform
17-
from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, has_any, query_bounding_box, query_chw
18-
19-
20-
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
21-
FillType = Union[int, float, Sequence[int], Sequence[float]]
16+
from ._utils import (
17+
_check_padding_arg,
18+
_check_padding_mode_arg,
19+
_check_sequence_input,
20+
_setup_angle,
21+
_setup_fill_arg,
22+
_setup_size,
23+
DType,
24+
FillType,
25+
has_all,
26+
has_any,
27+
query_bounding_box,
28+
query_chw,
29+
)
2230

2331

2432
class RandomHorizontalFlip(_RandomApplyTransform):
@@ -201,40 +209,6 @@ def forward(self, *inputs: Any) -> Any:
201209
return super().forward(*inputs)
202210

203211

204-
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
205-
if isinstance(fill, dict):
206-
for key, value in fill.items():
207-
# Check key for type
208-
_check_fill_arg(value)
209-
else:
210-
if not isinstance(fill, (numbers.Number, tuple, list)):
211-
raise TypeError("Got inappropriate fill arg")
212-
213-
214-
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
215-
_check_fill_arg(fill)
216-
217-
if isinstance(fill, dict):
218-
return fill
219-
220-
return defaultdict(lambda: fill) # type: ignore[arg-type, return-value]
221-
222-
223-
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
224-
if not isinstance(padding, (numbers.Number, tuple, list)):
225-
raise TypeError("Got inappropriate padding arg")
226-
227-
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
228-
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
229-
230-
231-
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
232-
# https://github.com/pytorch/vision/issues/6250
233-
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
234-
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
235-
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
236-
237-
238212
class Pad(Transform):
239213
def __init__(
240214
self,

torchvision/prototype/transforms/_utils.py

Lines changed: 49 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,61 @@
1-
from typing import Any, Callable, Tuple, Type, Union
1+
import numbers
2+
from collections import defaultdict
3+
4+
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
25

36
import PIL.Image
7+
8+
import torch
49
from torch.utils._pytree import tree_flatten
510
from torchvision._utils import sequence_to_str
611
from torchvision.prototype import features
712

813
from torchvision.prototype.transforms.functional._meta import get_chw
914
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
1015

16+
from typing_extensions import Literal
17+
18+
19+
# Type shortcuts:
20+
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
21+
FillType = Union[int, float, Sequence[int], Sequence[float]]
22+
23+
24+
def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> None:
25+
if isinstance(fill, dict):
26+
for key, value in fill.items():
27+
# Check key for type
28+
_check_fill_arg(value)
29+
else:
30+
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
31+
raise TypeError("Got inappropriate fill arg")
32+
33+
34+
def _setup_fill_arg(
35+
fill: Optional[Union[FillType, Dict[Type, FillType]]]
36+
) -> Union[Dict[Type, FillType], Dict[Type, None]]:
37+
_check_fill_arg(fill)
38+
39+
if isinstance(fill, dict):
40+
return fill
41+
42+
return defaultdict(lambda: fill) # type: ignore[return-value]
43+
44+
45+
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
46+
if not isinstance(padding, (numbers.Number, tuple, list)):
47+
raise TypeError("Got inappropriate padding arg")
48+
49+
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
50+
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
51+
52+
53+
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
54+
# https://github.com/pytorch/vision/issues/6250
55+
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
56+
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
57+
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
58+
1159

1260
def query_bounding_box(sample: Any) -> features.BoundingBox:
1361
flat_sample, _ = tree_flatten(sample)

0 commit comments

Comments
 (0)