diff --git a/.github/workflows/prototype-tests.yml b/.github/workflows/prototype-tests.yml index e9832860c40..5e9ca360d08 100644 --- a/.github/workflows/prototype-tests.yml +++ b/.github/workflows/prototype-tests.yml @@ -43,6 +43,15 @@ jobs: id: setup run: exit 0 + - name: Run prototype features tests + shell: bash + run: | + pytest \ + --durations=20 \ + --cov=torchvision/prototype/features \ + --cov-report=term-missing \ + test/test_prototype_features*.py + - name: Run prototype datasets tests if: success() || ( failure() && steps.setup.conclusion == 'success' ) shell: bash diff --git a/test/test_prototype_features.py b/test/test_prototype_features.py new file mode 100644 index 00000000000..65f30e9d569 --- /dev/null +++ b/test/test_prototype_features.py @@ -0,0 +1,72 @@ +import torch +from torchvision.prototype import features + + +def test_isinstance(): + assert isinstance( + features.Label([0, 1, 0], categories=["foo", "bar"]), + torch.Tensor, + ) + + +def test_wrapping_no_copy(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + assert label.data_ptr() == tensor.data_ptr() + + +def test_to_wrapping(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + label_to = label.to(torch.int32) + + assert type(label_to) is features.Label + assert label_to.dtype is torch.int32 + assert label_to.categories is label.categories + + +def test_to_feature_reference(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32) + + tensor_to = tensor.to(label) + + assert type(tensor_to) is torch.Tensor + assert tensor_to.dtype is torch.int32 + + +def test_clone_wrapping(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + label_clone = label.clone() + + assert type(label_clone) is features.Label + assert label_clone.data_ptr() != label.data_ptr() + assert label_clone.categories is label.categories + + +def test_other_op_no_wrapping(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + # any operation besides .to() and .clone() will do here + output = label * 2 + + assert type(output) is torch.Tensor + + +def test_new_like(): + tensor = torch.tensor([0, 1, 0], dtype=torch.int64) + label = features.Label(tensor, categories=["foo", "bar"]) + + # any operation besides .to() and .clone() will do here + output = label * 2 + + label_new = features.Label.new_like(label, output) + + assert type(label_new) is features.Label + assert label_new.data_ptr() == output.data_ptr() + assert label_new.categories is label.categories diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index b3f2172895d..a2005f4d5c5 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -89,6 +89,13 @@ def __torch_function__( with DisableTorchFunction(): output = func(*args, **kwargs) + # The __torch_function__ protocol will invoke this method on all types involved in the computation by walking + # the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke + # `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a + # case. + if not isinstance(args[0], cls): + return output + if func is torch.Tensor.clone: return cls.new_like(args[0], output) elif func is torch.Tensor.to: