Skip to content

Commit 086bf45

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] prevent feature wrapping if the feature is not the primary operand (#6095)
Summary: * prevent feature wrapping if the feature is not the primary operand * explicitly add feature tests to CI Reviewed By: datumbox Differential Revision: D40138743 fbshipit-source-id: b5e523ce612b7380f8f9d11565f39b0fb6ef7b22
1 parent dc23572 commit 086bf45

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

.github/workflows/prototype-tests.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ jobs:
4343
id: setup
4444
run: exit 0
4545

46+
- name: Run prototype features tests
47+
shell: bash
48+
run: |
49+
pytest \
50+
--durations=20 \
51+
--cov=torchvision/prototype/features \
52+
--cov-report=term-missing \
53+
test/test_prototype_features*.py
54+
4655
- name: Run prototype datasets tests
4756
if: success() || ( failure() && steps.setup.conclusion == 'success' )
4857
shell: bash

test/test_prototype_features.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import torch
2+
from torchvision.prototype import features
3+
4+
5+
def test_isinstance():
6+
assert isinstance(
7+
features.Label([0, 1, 0], categories=["foo", "bar"]),
8+
torch.Tensor,
9+
)
10+
11+
12+
def test_wrapping_no_copy():
13+
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
14+
label = features.Label(tensor, categories=["foo", "bar"])
15+
16+
assert label.data_ptr() == tensor.data_ptr()
17+
18+
19+
def test_to_wrapping():
20+
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
21+
label = features.Label(tensor, categories=["foo", "bar"])
22+
23+
label_to = label.to(torch.int32)
24+
25+
assert type(label_to) is features.Label
26+
assert label_to.dtype is torch.int32
27+
assert label_to.categories is label.categories
28+
29+
30+
def test_to_feature_reference():
31+
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
32+
label = features.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
33+
34+
tensor_to = tensor.to(label)
35+
36+
assert type(tensor_to) is torch.Tensor
37+
assert tensor_to.dtype is torch.int32
38+
39+
40+
def test_clone_wrapping():
41+
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
42+
label = features.Label(tensor, categories=["foo", "bar"])
43+
44+
label_clone = label.clone()
45+
46+
assert type(label_clone) is features.Label
47+
assert label_clone.data_ptr() != label.data_ptr()
48+
assert label_clone.categories is label.categories
49+
50+
51+
def test_other_op_no_wrapping():
52+
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
53+
label = features.Label(tensor, categories=["foo", "bar"])
54+
55+
# any operation besides .to() and .clone() will do here
56+
output = label * 2
57+
58+
assert type(output) is torch.Tensor
59+
60+
61+
def test_new_like():
62+
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
63+
label = features.Label(tensor, categories=["foo", "bar"])
64+
65+
# any operation besides .to() and .clone() will do here
66+
output = label * 2
67+
68+
label_new = features.Label.new_like(label, output)
69+
70+
assert type(label_new) is features.Label
71+
assert label_new.data_ptr() == output.data_ptr()
72+
assert label_new.categories is label.categories

torchvision/prototype/features/_feature.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ def __torch_function__(
8989
with DisableTorchFunction():
9090
output = func(*args, **kwargs)
9191

92+
# The __torch_function__ protocol will invoke this method on all types involved in the computation by walking
93+
# the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke
94+
# `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a
95+
# case.
96+
if not isinstance(args[0], cls):
97+
return output
98+
9299
if func is torch.Tensor.clone:
93100
return cls.new_like(args[0], output)
94101
elif func is torch.Tensor.to:

0 commit comments

Comments
 (0)