-
Notifications
You must be signed in to change notification settings - Fork 7.1k
port RandomHorizontalFlip to prototype API #5563
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
b24de4c
02f243b
28b2e2c
825c4f8
c9d5ca9
2b7cae6
6e84c7a
52dc452
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2,9 +2,10 @@ | |||||||||
|
||||||||||
import pytest | ||||||||||
import torch | ||||||||||
from common_utils import assert_equal | ||||||||||
from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels | ||||||||||
from torchvision.prototype import transforms, features | ||||||||||
from torchvision.transforms.functional import to_pil_image | ||||||||||
from torchvision.transforms.functional import to_pil_image, pil_to_tensor | ||||||||||
|
||||||||||
|
||||||||||
def make_vanilla_tensor_images(*args, **kwargs): | ||||||||||
|
@@ -66,10 +67,10 @@ def parametrize_from_transforms(*transforms): | |||||||||
class TestSmoke: | ||||||||||
@parametrize_from_transforms( | ||||||||||
transforms.RandomErasing(p=1.0), | ||||||||||
transforms.HorizontalFlip(), | ||||||||||
transforms.Resize([16, 16]), | ||||||||||
transforms.CenterCrop([16, 16]), | ||||||||||
transforms.ConvertImageDtype(), | ||||||||||
transforms.RandomHorizontalFlip(), | ||||||||||
) | ||||||||||
def test_common(self, transform, input): | ||||||||||
transform(input) | ||||||||||
|
@@ -152,3 +153,56 @@ def test_normalize(self, transform, input): | |||||||||
) | ||||||||||
def test_random_resized_crop(self, transform, input): | ||||||||||
transform(input) | ||||||||||
|
||||||||||
|
||||||||||
class TestRandomHorizontalFlip: | ||||||||||
def input_tensor(self, dtype: torch.dtype = torch.float32) -> torch.Tensor: | ||||||||||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
return torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype) | ||||||||||
|
||||||||||
def expected_tensor(self, dtype: torch.dtype = torch.float32) -> torch.Tensor: | ||||||||||
return torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure we need specific values here. I think we should be good to have random image inputs and use vision/torchvision/transforms/functional_tensor.py Lines 126 to 129 in a8bde78
but this is also what I would use to produce a expected output if I didn't know the internals of the kernel. cc @NicolasHug for an opinion There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I used specific values to have full control over the creation of the expecting result. I preferred to not use the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No strong opinion on my side for this specific transform. In some cases it's valuable to have a simple implementation of the transform that we're testing - it helps understanding what's going on with simpler code, and we can call it on arbitrary input. In the case of If we're confident that this hard-coded input covers all of what we might want to test against, then that's fine. |
||||||||||
|
||||||||||
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"]) | ||||||||||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
def test_simple_tensor(self, p: float): | ||||||||||
input = self.input_tensor() | ||||||||||
|
||||||||||
actual = transforms.RandomHorizontalFlip(p=p)(input) | ||||||||||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
||||||||||
expected = self.expected_tensor() if p == 1.0 else input | ||||||||||
assert_equal(expected, actual) | ||||||||||
|
||||||||||
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"]) | ||||||||||
def test_pil_image(self, p: float): | ||||||||||
input = self.input_tensor(dtype=torch.uint8) | ||||||||||
|
||||||||||
actual = transforms.RandomHorizontalFlip(p=p)(to_pil_image(input)) | ||||||||||
|
||||||||||
expected = self.expected_tensor(dtype=torch.uint8) if p == 1.0 else input | ||||||||||
assert_equal(expected, pil_to_tensor(actual)) | ||||||||||
|
||||||||||
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"]) | ||||||||||
def test_features_image(self, p: float): | ||||||||||
input = self.input_tensor() | ||||||||||
|
||||||||||
actual = transforms.RandomHorizontalFlip(p=p)(features.Image(input)) | ||||||||||
|
||||||||||
expected = self.expected_tensor() if p == 1.0 else input | ||||||||||
assert_equal(features.Image(expected), actual) | ||||||||||
|
||||||||||
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"]) | ||||||||||
def test_features_segmentation_mask(self, p: float): | ||||||||||
input = features.SegmentationMask(self.input_tensor()) | ||||||||||
|
||||||||||
actual = transforms.RandomHorizontalFlip(p=p)(input) | ||||||||||
|
||||||||||
expected = self.expected_tensor() if p == 1.0 else input | ||||||||||
assert_equal(features.SegmentationMask(expected), actual) | ||||||||||
|
||||||||||
@pytest.mark.parametrize("p", [0.0, 1.0], ids=["p=0", "p=1"]) | ||||||||||
def test_features_bounding_box(self, p: float): | ||||||||||
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) | ||||||||||
|
||||||||||
actual = transforms.RandomHorizontalFlip(p=p)(input) | ||||||||||
|
||||||||||
expected = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input | ||||||||||
assert_equal(features.BoundingBox.new_like(input, expected), actual) | ||||||||||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.