Skip to content

Commit de07666

Browse files
YosuaMichaelfacebook-github-bot
authored andcommitted
[fbsync] [proto] Argument fill can accept dict of base types (#6586)
Summary: Co-authored-by: Vasilis Vryniotis <[email protected]> Reviewed By: jdsgomes Differential Revision: D39543288 fbshipit-source-id: 5763f1d98152303251ff3d3dd60eac0ebaeada2a
1 parent 6cb70f8 commit de07666

File tree

4 files changed

+93
-17
lines changed

4 files changed

+93
-17
lines changed

test/test_prototype_transforms.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,28 @@ def test__transform(self, padding, fill, padding_mode, mocker):
378378

379379
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
380380

381+
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
382+
def test__transform_image_mask(self, fill, mocker):
383+
transform = transforms.Pad(1, fill=fill, padding_mode="constant")
384+
385+
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
386+
image = features.Image(torch.rand(3, 32, 32))
387+
mask = features.Mask(torch.randint(0, 5, size=(32, 32)))
388+
inpt = [image, mask]
389+
_ = transform(inpt)
390+
391+
if isinstance(fill, int):
392+
calls = [
393+
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
394+
mocker.call(mask, padding=1, fill=0, padding_mode="constant"),
395+
]
396+
else:
397+
calls = [
398+
mocker.call(image, padding=1, fill=fill[type(image)], padding_mode="constant"),
399+
mocker.call(mask, padding=1, fill=fill[type(mask)], padding_mode="constant"),
400+
]
401+
fn.assert_has_calls(calls)
402+
381403

382404
class TestRandomZoomOut:
383405
def test_assertions(self):
@@ -400,7 +422,6 @@ def test__get_params(self, fill, side_range, mocker):
400422

401423
params = transform._get_params(image)
402424

403-
assert params["fill"] == fill
404425
assert len(params["padding"]) == 4
405426
assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
406427
assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h
@@ -426,7 +447,34 @@ def test__transform(self, fill, side_range, mocker):
426447
torch.rand(1) # random apply changes random state
427448
params = transform._get_params(inpt)
428449

429-
fn.assert_called_once_with(inpt, **params)
450+
fn.assert_called_once_with(inpt, **params, fill=fill)
451+
452+
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
453+
def test__transform_image_mask(self, fill, mocker):
454+
transform = transforms.RandomZoomOut(fill=fill, p=1.0)
455+
456+
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
457+
image = features.Image(torch.rand(3, 32, 32))
458+
mask = features.Mask(torch.randint(0, 5, size=(32, 32)))
459+
inpt = [image, mask]
460+
461+
torch.manual_seed(12)
462+
_ = transform(inpt)
463+
torch.manual_seed(12)
464+
torch.rand(1) # random apply changes random state
465+
params = transform._get_params(inpt)
466+
467+
if isinstance(fill, int):
468+
calls = [
469+
mocker.call(image, **params, fill=fill),
470+
mocker.call(mask, **params, fill=0),
471+
]
472+
else:
473+
calls = [
474+
mocker.call(image, **params, fill=fill[type(image)]),
475+
mocker.call(mask, **params, fill=fill[type(mask)]),
476+
]
477+
fn.assert_has_calls(calls)
430478

431479

432480
class TestRandomRotation:

torchvision/prototype/features/_mask.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,14 @@ def pad(
5858
if not isinstance(padding, int):
5959
padding = list(padding)
6060

61-
output = self._F.pad_mask(self, padding, padding_mode=padding_mode)
61+
if isinstance(fill, (int, float)) or fill is None:
62+
if fill is None:
63+
fill = 0
64+
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
65+
else:
66+
# Let's raise an error for vector fill on masks
67+
raise ValueError("Non-scalar fill value is not supported")
68+
6269
return Mask.new_like(self, output)
6370

6471
def rotate(

torchvision/prototype/transforms/_geometry.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import math
22
import numbers
33
import warnings
4-
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
4+
from collections import defaultdict
5+
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
56

67
import PIL.Image
78
import torch
@@ -16,6 +17,7 @@
1617

1718

1819
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
20+
FillType = Union[int, float, Sequence[int], Sequence[float]]
1921

2022

2123
class RandomHorizontalFlip(_RandomApplyTransform):
@@ -196,9 +198,21 @@ def forward(self, *inputs: Any) -> Any:
196198
return super().forward(*inputs)
197199

198200

199-
def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
200-
if not isinstance(fill, (numbers.Number, tuple, list)):
201-
raise TypeError("Got inappropriate fill arg")
201+
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
202+
if isinstance(fill, dict):
203+
for key, value in fill.items():
204+
# Check key for type
205+
_check_fill_arg(value)
206+
else:
207+
if not isinstance(fill, (numbers.Number, tuple, list)):
208+
raise TypeError("Got inappropriate fill arg")
209+
210+
211+
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
212+
if isinstance(fill, dict):
213+
return fill
214+
else:
215+
return defaultdict(lambda: fill, {features.Mask: 0}) # type: ignore[arg-type, return-value]
202216

203217

204218
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
@@ -220,7 +234,7 @@ class Pad(Transform):
220234
def __init__(
221235
self,
222236
padding: Union[int, Sequence[int]],
223-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
237+
fill: Union[FillType, Dict[Type, FillType]] = 0,
224238
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
225239
) -> None:
226240
super().__init__()
@@ -230,24 +244,25 @@ def __init__(
230244
_check_padding_mode_arg(padding_mode)
231245

232246
self.padding = padding
233-
self.fill = fill
247+
self.fill = _setup_fill_arg(fill)
234248
self.padding_mode = padding_mode
235249

236250
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
237-
return F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
251+
fill = self.fill[type(inpt)]
252+
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
238253

239254

240255
class RandomZoomOut(_RandomApplyTransform):
241256
def __init__(
242257
self,
243-
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
258+
fill: Union[FillType, Dict[Type, FillType]] = 0,
244259
side_range: Sequence[float] = (1.0, 4.0),
245260
p: float = 0.5,
246261
) -> None:
247262
super().__init__(p=p)
248263

249264
_check_fill_arg(fill)
250-
self.fill = fill
265+
self.fill = _setup_fill_arg(fill)
251266

252267
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
253268

@@ -256,7 +271,7 @@ def __init__(
256271
raise ValueError(f"Invalid canvas side range provided {side_range}.")
257272

258273
def _get_params(self, sample: Any) -> Dict[str, Any]:
259-
orig_c, orig_h, orig_w = query_chw(sample)
274+
_, orig_h, orig_w = query_chw(sample)
260275

261276
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
262277
canvas_width = int(orig_w * r)
@@ -269,10 +284,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
269284
bottom = canvas_height - (top + orig_h)
270285
padding = [left, top, right, bottom]
271286

272-
return dict(padding=padding, fill=self.fill)
287+
return dict(padding=padding)
273288

274289
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
275-
return F.pad(inpt, **params)
290+
fill = self.fill[type(inpt)]
291+
return F.pad(inpt, **params, fill=fill)
276292

277293

278294
class RandomRotation(Transform):

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -635,14 +635,19 @@ def _pad_with_vector_fill(
635635
return output
636636

637637

638-
def pad_mask(mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant") -> torch.Tensor:
638+
def pad_mask(
639+
mask: torch.Tensor,
640+
padding: Union[int, List[int]],
641+
padding_mode: str = "constant",
642+
fill: Optional[Union[int, float]] = 0,
643+
) -> torch.Tensor:
639644
if mask.ndim < 3:
640645
mask = mask.unsqueeze(0)
641646
needs_squeeze = True
642647
else:
643648
needs_squeeze = False
644649

645-
output = pad_image_tensor(img=mask, padding=padding, fill=0, padding_mode=padding_mode)
650+
output = pad_image_tensor(img=mask, padding=padding, fill=fill, padding_mode=padding_mode)
646651

647652
if needs_squeeze:
648653
output = output.squeeze(0)

0 commit comments

Comments
 (0)