Skip to content

Commit dcd1e42

Browse files
Fix RandomResizedCrop scale & ratio argument. (#8944)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 05decc3 commit dcd1e42

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

test/test_transforms_v2.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -3758,12 +3758,18 @@ def test_transform_errors_warnings(self):
37583758
with pytest.raises(ValueError, match="provide only two dimensions"):
37593759
transforms.RandomResizedCrop(size=(1, 2, 3))
37603760

3761-
with pytest.raises(TypeError, match="Scale should be a sequence"):
3761+
with pytest.raises(TypeError, match="Scale should be a sequence of two floats."):
37623762
transforms.RandomResizedCrop(size=self.INPUT_SIZE, scale=123)
37633763

3764-
with pytest.raises(TypeError, match="Ratio should be a sequence"):
3764+
with pytest.raises(TypeError, match="Ratio should be a sequence of two floats."):
37653765
transforms.RandomResizedCrop(size=self.INPUT_SIZE, ratio=123)
37663766

3767+
with pytest.raises(TypeError, match="Ratio should be a sequence of two floats."):
3768+
transforms.RandomResizedCrop(size=self.INPUT_SIZE, ratio=[1, 2, 3])
3769+
3770+
with pytest.raises(TypeError, match="Scale should be a sequence of two floats."):
3771+
transforms.RandomResizedCrop(size=self.INPUT_SIZE, scale=[1, 2, 3])
3772+
37673773
for param in ["scale", "ratio"]:
37683774
with pytest.warns(match="Scale and ratio should be of kind"):
37693775
transforms.RandomResizedCrop(size=self.INPUT_SIZE, **{param: [1, 0]})

torchvision/transforms/v2/_geometry.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,10 @@ def __init__(
254254
super().__init__()
255255
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
256256

257-
if not isinstance(scale, Sequence):
258-
raise TypeError("Scale should be a sequence")
259-
if not isinstance(ratio, Sequence):
260-
raise TypeError("Ratio should be a sequence")
257+
if not isinstance(scale, Sequence) or len(scale) != 2:
258+
raise TypeError("Scale should be a sequence of two floats.")
259+
if not isinstance(ratio, Sequence) or len(ratio) != 2:
260+
raise TypeError("Ratio should be a sequence of two floats.")
261261
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
262262
warnings.warn("Scale and ratio should be of kind (min, max)")
263263

0 commit comments

Comments
 (0)