1
1
import math
2
- import numbers
3
- from typing import Any , Callable , cast , Dict , List , Optional , Sequence , Tuple , Type , TypeVar , Union
2
+ from typing import Any , Callable , cast , Dict , List , Optional , Tuple , Type , TypeVar , Union
4
3
5
4
import PIL .Image
6
5
import torch
10
9
from torchvision .prototype .transforms import AutoAugmentPolicy , functional as F , InterpolationMode , Transform
11
10
from torchvision .prototype .transforms .functional ._meta import get_chw
12
11
13
- from ._utils import _isinstance
12
+ from ._utils import _isinstance , _setup_fill_arg , FillType
14
13
15
14
K = TypeVar ("K" )
16
15
V = TypeVar ("V" )
@@ -21,14 +20,11 @@ def __init__(
21
20
self ,
22
21
* ,
23
22
interpolation : InterpolationMode = InterpolationMode .NEAREST ,
24
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
23
+ fill : Optional [ Union [FillType , Dict [ Type , FillType ]]] = None ,
25
24
) -> None :
26
25
super ().__init__ ()
27
26
self .interpolation = interpolation
28
-
29
- if not isinstance (fill , (numbers .Number , tuple , list )):
30
- raise TypeError ("Got inappropriate fill arg" )
31
- self .fill = fill
27
+ self .fill = _setup_fill_arg (fill )
32
28
33
29
def _get_random_item (self , dct : Dict [K , V ]) -> Tuple [K , V ]:
34
30
keys = tuple (dct .keys ())
@@ -63,19 +59,14 @@ def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
63
59
64
60
def _apply_image_transform (
65
61
self ,
66
- image : Any ,
62
+ image : Union [ torch . Tensor , PIL . Image . Image , features . Image ] ,
67
63
transform_id : str ,
68
64
magnitude : float ,
69
65
interpolation : InterpolationMode ,
70
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]],
66
+ fill : Union [Dict [ Type , FillType ], Dict [ Type , None ]],
71
67
) -> Any :
72
-
73
- # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
74
- # So, we have to put fill as None if fill == 0
75
- # This is due to BC with stable API which has fill = None by default
76
- fill_ = F ._geometry ._convert_fill_arg (fill )
77
- if isinstance (fill , int ) and fill == 0 :
78
- fill_ = None
68
+ fill_ = fill [type (image )]
69
+ fill_ = F ._geometry ._convert_fill_arg (fill_ )
79
70
80
71
if transform_id == "Identity" :
81
72
return image
@@ -186,7 +177,7 @@ def __init__(
186
177
self ,
187
178
policy : AutoAugmentPolicy = AutoAugmentPolicy .IMAGENET ,
188
179
interpolation : InterpolationMode = InterpolationMode .NEAREST ,
189
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
180
+ fill : Optional [ Union [FillType , Dict [ Type , FillType ]]] = None ,
190
181
) -> None :
191
182
super ().__init__ (interpolation = interpolation , fill = fill )
192
183
self .policy = policy
@@ -286,7 +277,7 @@ def forward(self, *inputs: Any) -> Any:
286
277
sample = inputs if len (inputs ) > 1 else inputs [0 ]
287
278
288
279
id , image = self ._extract_image (sample )
289
- num_channels , height , width = get_chw (image )
280
+ _ , height , width = get_chw (image )
290
281
291
282
policy = self ._policies [int (torch .randint (len (self ._policies ), ()))]
292
283
@@ -346,7 +337,7 @@ def __init__(
346
337
magnitude : int = 9 ,
347
338
num_magnitude_bins : int = 31 ,
348
339
interpolation : InterpolationMode = InterpolationMode .NEAREST ,
349
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
340
+ fill : Optional [ Union [FillType , Dict [ Type , FillType ]]] = None ,
350
341
) -> None :
351
342
super ().__init__ (interpolation = interpolation , fill = fill )
352
343
self .num_ops = num_ops
@@ -402,7 +393,7 @@ def __init__(
402
393
self ,
403
394
num_magnitude_bins : int = 31 ,
404
395
interpolation : InterpolationMode = InterpolationMode .NEAREST ,
405
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
396
+ fill : Optional [ Union [FillType , Dict [ Type , FillType ]]] = None ,
406
397
):
407
398
super ().__init__ (interpolation = interpolation , fill = fill )
408
399
self .num_magnitude_bins = num_magnitude_bins
@@ -462,7 +453,7 @@ def __init__(
462
453
alpha : float = 1.0 ,
463
454
all_ops : bool = True ,
464
455
interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
465
- fill : Union [int , float , Sequence [ int ], Sequence [ float ]] = 0 ,
456
+ fill : Optional [ Union [FillType , Dict [ Type , FillType ]]] = None ,
466
457
) -> None :
467
458
super ().__init__ (interpolation = interpolation , fill = fill )
468
459
self ._PARAMETER_MAX = 10
0 commit comments