-
Notifications
You must be signed in to change notification settings - Fork 7.1k
prevent feature wrapping if the feature is not the primary operand #6095
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import torch | ||
from torchvision.prototype import features | ||
|
||
|
||
def test_isinstance():
    """A Label is a torch.Tensor subclass, so a plain-Tensor isinstance check must pass."""
    label = features.Label([0, 1, 0], categories=["foo", "bar"])
    assert isinstance(label, torch.Tensor)
|
||
|
||
def test_wrapping_no_copy():
    """Wrapping an existing tensor in a Label must not copy the underlying storage."""
    data = torch.tensor([0, 1, 0], dtype=torch.int64)
    wrapped = features.Label(data, categories=["foo", "bar"])

    # Same data_ptr means both objects share one storage buffer.
    assert wrapped.data_ptr() == data.data_ptr()
|
||
|
||
def test_to_wrapping():
    """`.to()` on a Label must return a Label, carrying dtype and metadata correctly."""
    data = torch.tensor([0, 1, 0], dtype=torch.int64)
    original = features.Label(data, categories=["foo", "bar"])

    converted = original.to(torch.int32)

    # The feature subclass is preserved, the dtype is converted, and the
    # categories metadata is carried over by reference.
    assert type(converted) is features.Label
    assert converted.dtype is torch.int32
    assert converted.categories is original.categories
|
||
|
||
def test_to_feature_reference():
    """Using a feature as the *reference* of `Tensor.to` must not wrap the result.

    Regression test for pytorch/vision#6094: `plain_tensor.to(label)` should
    yield a plain Tensor with the reference's dtype, not a Label.
    """
    plain = torch.tensor([0, 1, 0], dtype=torch.int64)
    reference = features.Label(plain, categories=["foo", "bar"]).to(torch.int32)

    result = plain.to(reference)

    # The primary operand is a plain tensor, so the output stays a plain
    # tensor; only the dtype is taken from the Label reference.
    assert type(result) is torch.Tensor
    assert result.dtype is torch.int32
|
||
|
||
def test_clone_wrapping():
    """`.clone()` must produce a new Label with fresh storage and shared metadata."""
    data = torch.tensor([0, 1, 0], dtype=torch.int64)
    original = features.Label(data, categories=["foo", "bar"])

    copy = original.clone()

    # A clone keeps the feature subclass and metadata but owns new storage.
    assert type(copy) is features.Label
    assert copy.data_ptr() != original.data_ptr()
    assert copy.categories is original.categories
|
||
|
||
def test_other_op_no_wrapping():
    """Operations other than `.to()` / `.clone()` must return a plain Tensor."""
    data = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = features.Label(data, categories=["foo", "bar"])

    # Any operation besides .to() and .clone() will do here; multiplication
    # is simply a representative example.
    result = label * 2

    assert type(result) is torch.Tensor
|
||
|
||
def test_new_like():
    """`Label.new_like` re-wraps a plain tensor without copying, reusing metadata."""
    data = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = features.Label(data, categories=["foo", "bar"])

    # Any operation besides .to() and .clone() will do here; it yields a
    # plain tensor that we then re-wrap explicitly.
    unwrapped = label * 2

    rewrapped = features.Label.new_like(label, unwrapped)

    # new_like restores the feature subclass, shares the input's storage,
    # and carries the reference label's metadata by reference.
    assert type(rewrapped) is features.Label
    assert rewrapped.data_ptr() == unwrapped.data_ptr()
    assert rewrapped.categories is label.categories
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,6 +89,13 @@ def __torch_function__( | |
with DisableTorchFunction(): | ||
output = func(*args, **kwargs) | ||
|
||
# The __torch_function__ protocol will invoke this method on all types involved in the computation by walking | ||
# the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke | ||
# `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a | ||
# case. | ||
if not isinstance(args[0], cls): | ||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return output | ||
Comment on lines
+96
to
+97
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although the behavior should be kept for all ops where we might wrap, right now this only applies to Thus, we could also merge this into the |
||
|
||
if func is torch.Tensor.clone: | ||
return cls.new_like(args[0], output) | ||
elif func is torch.Tensor.to: | ||
|
Uh oh!
There was an error while loading. Please reload this page.