Skip to content

Commit 8aaf3f1

Browse files
ezyangSacha Refshauge
authored and
Sacha Refshauge
committed
Revert D26815021: Revert D26744062: Add assert_async
Test Plan: revert-hammer Differential Revision: D26815021 Original commit changeset: 972eaafcdf14 fbshipit-source-id: e528260e1aa91df1873c73af00aa57addd671607
1 parent f56cabe commit 8aaf3f1

File tree

7 files changed

+110
-0
lines changed

7 files changed

+110
-0
lines changed

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ bool is_nonzero(const Tensor& self) {
193193
TORCH_INTERNAL_ASSERT(false, "Expected non-Tensor backend scalar");
194194
}
195195

196+
// CPU implementation of assert_async: raises through TORCH_CHECK, with the
// message below, when native::is_nonzero(self) evaluates to false.
// NOTE(review): validation of empty / multi-element tensors is not visible in
// this function; presumably it happens inside native::is_nonzero -- confirm.
void assert_async_cpu(const Tensor& self) {
197+
TORCH_CHECK(native::is_nonzero(self), "Expected Tensor with single nonzero value, but got zero");
198+
}
199+
196200
namespace {
197201

198202
// DO NOT USE THIS -- it's just an implementation detail of wrapped_scalar tensor below.

aten/src/ATen/native/cuda/TensorCompare.cu

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,26 @@ REGISTER_DISPATCH(where_kernel, &where_kernel_impl);
5959
REGISTER_DISPATCH(isposinf_stub, &isposinf_kernel_impl);
6060
REGISTER_DISPATCH(isneginf_stub, &isneginf_kernel_impl);
6161

62+
// Generic device-side check for assert_async_cuda: asserts, through
// CUDA_KERNEL_ASSERT, that the first element behind `input` compares unequal
// to 0. Only input[0] is read; assert_async_cuda launches this kernel with a
// single block of a single thread.
template <typename scalar_t>
63+
__global__ void assert_async_cuda_kernel(scalar_t* input) {
64+
CUDA_KERNEL_ASSERT(input[0] != 0);
65+
}
66+
67+
// Non-template overload for c10::complex<float> inputs: compares input[0]
// against an explicitly constructed complex zero rather than the integer
// literal 0 used by the generic kernel.
__global__ void assert_async_cuda_kernel(c10::complex<float>* input) {
68+
CUDA_KERNEL_ASSERT(input[0] != c10::complex<float>(0, 0));
69+
}
70+
// Non-template overload for c10::complex<double> inputs: compares input[0]
// against an explicitly constructed complex zero rather than the integer
// literal 0 used by the generic kernel.
__global__ void assert_async_cuda_kernel(c10::complex<double>* input) {
71+
CUDA_KERNEL_ASSERT(input[0] != c10::complex<double>(0, 0));
72+
}
73+
74+
// CUDA implementation of assert_async. Checks on the host that `self` holds
// exactly one element, then enqueues assert_async_cuda_kernel on the current
// CUDA stream. The element's value is only inspected on the device, and
// nothing in this function waits for the kernel to finish.
void assert_async_cuda(const Tensor& self) {
75+
auto n = self.numel();
76+
// Host-side size checks: these raise before any kernel is launched.
TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous");
77+
TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous");
78+
auto stream = at::cuda::getCurrentCUDAStream();
79+
// Dispatch on self.scalar_type(); inside the lambda, `scalar_t` is the element
// type supplied by the dispatch macro and selects the kernel overload.
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "assert_async_cuda", [&] {
80+
// Launch configuration: 1 block, 1 thread, 0 bytes of dynamic shared memory.
assert_async_cuda_kernel<<<1, 1, 0, stream>>>(self.data_ptr<scalar_t>());
81+
});
82+
}
83+
6284
}} // namespace at::native

aten/src/ATen/native/native_functions.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@
118118

119119
- func: align_tensors(Tensor[] tensors) -> Tensor[]
120120

121+
# Declares the assert_async operator: it takes a single Tensor, returns nothing,
# and dispatches to assert_async_cpu on CPU and assert_async_cuda on CUDA.
# Not assert because it's a keyword; not Assert because FX already
122+
# took that syntax
123+
# TODO: need to specify this is side-effectful somehow
124+
- func: assert_async(Tensor self) -> ()
125+
dispatch:
126+
CPU: assert_async_cpu
127+
CUDA: assert_async_cuda
128+
121129
- func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
122130
variants: method
123131

test/test_cuda.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,6 +1825,37 @@ def worker(rank):
18251825
t2.start()
18261826
"""])
18271827

1828+
def test_cuda_assert_async(self):
1829+
with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
1830+
torch.assert_async(torch.tensor([], device="cuda"))
1831+
with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
1832+
torch.assert_async(torch.tensor([0, 0], device="cuda"))
1833+
1834+
torch.assert_async(torch.tensor(1, device="cuda"))
1835+
torch.assert_async(torch.tensor(0.1, device="cuda"))
1836+
torch.assert_async(torch.tensor(-0.1, device="cuda"))
1837+
torch.assert_async(torch.tensor(True, device="cuda"))
1838+
torch.assert_async(torch.tensor(0 + 0.1j, device="cuda"))
1839+
1840+
fail_stmts = [
1841+
"torch.assert_async(torch.tensor(0, device='cuda'))",
1842+
"torch.assert_async(torch.tensor(0.0, device='cuda'))",
1843+
"torch.assert_async(torch.tensor(False, device='cuda'))",
1844+
"torch.assert_async(torch.tensor(0+ 0 j, device='cuda'))",
1845+
]
1846+
1847+
import subprocess
1848+
for stmt in fail_stmts:
1849+
with self.subTest(stmt=stmt):
1850+
r = subprocess.call([sys.executable, '-c', f"""\
1851+
import torch
1852+
1853+
{stmt}
1854+
torch.cuda.synchronize()
1855+
"""])
1856+
self.assertTrue(r != 0)
1857+
1858+
18281859
def test_grad_scaling_unscale(self, dtype=torch.float):
18291860
inv_scale = torch.full((1,), 0.25, dtype=torch.float, device="cuda:0")
18301861
found_inf = torch.full((1,), 0.0, dtype=torch.float, device="cuda:0")

test/test_torch.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2356,6 +2356,32 @@ def test_is_nonzero(self):
23562356
self.assertTrue(torch.tensor([1]).is_nonzero())
23572357
self.assertFalse(torch.tensor([[0]]).is_nonzero())
23582358
self.assertTrue(torch.tensor([[1]]).is_nonzero())
2359+
self.assertTrue(torch.tensor(0.1).is_nonzero())
2360+
self.assertTrue(torch.tensor(-0.1).is_nonzero())
2361+
self.assertFalse(torch.tensor(0.0).is_nonzero())
2362+
self.assertTrue(torch.tensor(True).is_nonzero())
2363+
self.assertFalse(torch.tensor(False).is_nonzero())
2364+
self.assertFalse(torch.tensor(0 + 0j).is_nonzero())
2365+
self.assertTrue(torch.tensor(0 + 0.1j).is_nonzero())
2366+
2367+
def test_assert_async(self):
2368+
with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
2369+
torch.assert_async(torch.tensor([]))
2370+
with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
2371+
torch.assert_async(torch.tensor([0, 0]))
2372+
with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
2373+
torch.assert_async(torch.tensor(0))
2374+
torch.assert_async(torch.tensor(1))
2375+
torch.assert_async(torch.tensor(0.1))
2376+
torch.assert_async(torch.tensor(-0.1))
2377+
with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
2378+
torch.assert_async(torch.tensor(0.0))
2379+
torch.assert_async(torch.tensor(True))
2380+
with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
2381+
torch.assert_async(torch.tensor(False))
2382+
torch.assert_async(torch.tensor(0 + 0.1j))
2383+
with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
2384+
torch.assert_async(torch.tensor(0 + 0j))
23592385

23602386
# NB: we must not be built with CUDA; if we are built with CUDA but no CUDA
23612387
# is available, we get a different error.

torch/_torch_docs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10565,6 +10565,24 @@ def merge_dicts(*dicts):
1056510565
device(type='cpu')
1056610566
""")
1056710567

10568+
add_docstr(torch.assert_async,
10569+
r"""
10570+
assert_async(tensor) -> void
10571+
10572+
Asynchronously assert that the contents of tensor are nonzero. For CPU tensors,
10573+
this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for
10574+
CUDA tensors, we DO NOT synchronize and you may only find out the assertion
10575+
failed at a later CUDA kernel launch. Asynchronous assertion can be helpful for
10576+
testing invariants in CUDA tensors without giving up performance. This function
10577+
is NOT intended to be used for regular error checking, as it will trash your CUDA
10578+
context if the assert fails (forcing you to restart your PyTorch process).
10579+
10580+
Args:
10581+
tensor (Tensor): a one element tensor to test to see if it is nonzero. Zero
10582+
elements (including False for boolean tensors) cause an assertion failure
10583+
to be raised.
10584+
""")
10585+
1056810586
add_docstr(torch.searchsorted,
1056910587
r"""
1057010588
searchsorted(sorted_sequence, values, *, out_int32=False, right=False, out=None) -> Tensor

torch/overrides.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
272272
torch.argmin: lambda input: -1,
273273
torch.argsort: lambda input, dim=None: -1,
274274
torch.asin: lambda input, out=None: -1,
275+
torch.assert_async: lambda input: -1,
275276
torch.arcsin: lambda input, out=None: -1,
276277
torch.asinh: lambda input, out=None: -1,
277278
torch.arcsinh: lambda input, out=None: -1,

0 commit comments

Comments
 (0)