
Commit d4bf76c

tugsbayasgalan authored and pytorchmergebot committed
Persist torch._assert in aten graph (#100101)

This PR introduces a new operator, aten._assert_async.msg, which takes a tensor value and an assertion message as inputs. As part of TorchDynamo, the use of torch._assert is replaced with this new operator so that make_fx also knows how to handle assertions. This is a subset of #98878; refer there for the historical reviews.

Pull Request resolved: #100101
Approved by: https://github.com/jansel
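For illustration, a minimal sketch of how the new overload behaves from Python once this change lands (based on the schema and CPU kernel added below; the sample tensor values and message are illustrative, and the error type follows the new tests):

    import torch

    # Passes: the tensor is nonzero, so the TORCH_CHECK in the CPU kernel succeeds.
    torch.ops.aten._assert_async.msg(torch.tensor(1), "value must be nonzero")

    # Raises RuntimeError("value must be nonzero"): a zero tensor fails the check.
    # torch.ops.aten._assert_async.msg(torch.tensor(0), "value must be nonzero")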
1 parent cef15ec commit d4bf76c

File tree: 12 files changed, +100 −7 lines

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 4 additions & 0 deletions
@@ -405,6 +405,10 @@ void _assert_async_cpu(const Tensor& self) {
   TORCH_CHECK(native::is_nonzero(self), "Expected Tensor with single nonzero value, but got zero");
 }
 
+void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
+  TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion is failed");
+}
+
 // Sorting-based algorithm for isin(); used when the number of test elements is large.
 static void isin_sorting(
     const Tensor& elements,

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
@@ -170,6 +170,9 @@
     CPU: _assert_async_cpu
     CUDA: _assert_async_cuda
 
+- func: _assert_async.msg(Tensor self, str assert_msg) -> ()
+  dispatch:
+    CPU: _assert_async_msg_cpu
 
 - func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()

test/dynamo/test_export.py

Lines changed: 25 additions & 0 deletions
@@ -2583,6 +2583,31 @@ def f(x):
         ):
             gm, _ = torch._dynamo.export(f, torch.randn(5, 6), aten_graph=True)
 
+    @config.patch(assume_static_by_default=False)
+    def test_export_persist_assert(self):
+        def f(x):
+            assert x.shape[0] > 4, "Shape must be more than 4"
+            return x.cos() + x.sin()
+
+        gm, guard = torch._dynamo.export(
+            f, torch.randn(5, 4, 6), aten_graph=True, tracing_mode="symbolic"
+        )
+
+        def has_aten_op(gm, op):
+            for node in gm.graph.nodes:
+                if node.target == op:
+                    return True
+            return False
+
+        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
+
+        gm.graph.eliminate_dead_code()
+        gm.recompile()
+        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
+
+        with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
+            gm(torch.randn(3, 4, 5))
+
     def test_access_class_method_from_user_class(self):
         class A:
             @classmethod

test/dynamo/test_repros.py

Lines changed: 2 additions & 2 deletions
@@ -2496,7 +2496,7 @@ def f(x):
         exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
         self.assertTrue(same(exported(*args), f(*args)))
 
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(RuntimeError, "First dim need to be 3"):
             exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
 
     def test_not_rewrite_assert_for_other_errors(self):
@@ -2521,7 +2521,7 @@ def f(x):
         exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
         self.assertTrue(same(exported(*args), f(*args)))
 
-        with self.assertRaisesRegex(AssertionError, ""):
+        with self.assertRaisesRegex(RuntimeError, "assertion error"):
            exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
 
     def test_rewrite_assert_with_non_string_msg(self):

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ aten::_amp_update_scale
 aten::_amp_update_scale.out
 aten::_amp_update_scale_
 aten::_assert_async
+aten::_assert_async.msg
 aten::_cdist_backward
 aten::_cdist_backward.out
 aten::_cdist_forward

test/inductor/test_torchinductor.py

Lines changed: 11 additions & 0 deletions
@@ -1901,6 +1901,17 @@ def fn(a):
         with self.assertRaisesRegex(RuntimeError, ""):
             fn(torch.randn(1, 5))
 
+    def test_inductor_assert(self):
+        @torch._dynamo.optimize("inductor", dynamic=True)
+        def fn(a):
+            assert a.shape[0] >= 2 and a.shape[1] >= 4
+            return a.cos()
+
+        inp = torch.randn(2, 4, 6)
+        torch._dynamo.mark_dynamic(inp, 0)
+        torch._dynamo.mark_dynamic(inp, 1)
+        self.assertEqual(fn(inp), inp.cos())
+
     def test_split(self):
         def fn(a):
             t = torch.split(a, 3, -1)

torch/_dynamo/symbolic_convert.py

Lines changed: 33 additions & 4 deletions
@@ -55,7 +55,13 @@
     GlobalWeakRefSource,
     LocalSource,
 )
-from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
+from .utils import (
+    counters,
+    get_fake_value,
+    graph_break_dup_warning_checker,
+    istype,
+    proxy_args_kwargs,
+)
 from .variables.base import MutableLocal, typestr, VariableTracker
 from .variables.builder import VariableBuilder, wrap_fx_proxy
 from .variables.builtin import BuiltinVariable
@@ -249,12 +255,35 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
             self.jump(inst)
             return
 
-        # Manually insert torch._assert instead of python assert and jump over
+        # TODO maybe should respect DtoH sync intention of users later??
+        # Manually insert torch._assert_async instead of python assert and jump over
         # assert related instructions as we don't need them anymore.
+
+        # if we see Tensor as assert statement, no need to call scalar_tensor
+        if isinstance(value, TensorVariable):
+            self.output.create_proxy(
+                "call_function",
+                torch._assert_async,
+                *proxy_args_kwargs((value, error_msg), {}),
+            )
+            self.jump(inst)
+            return
+
+        scalar_to_tensor_proxy = self.output.create_proxy(
+            "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
+        )
+
+        scalar_to_tensor = wrap_fx_proxy(
+            self,
+            scalar_to_tensor_proxy,
+            example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
+            **VariableTracker.propagate([value]),
+        )
+
         self.output.create_proxy(
             "call_function",
-            torch._assert,
-            *proxy_args_kwargs((value, error_msg), {}),
+            torch._assert_async,
+            *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
         )
         self.jump(inst)
         return
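In effect, a Python-level assert with a non-tensor condition is now traced as roughly the following two calls (a simplified eager-mode sketch of the proxies emitted above, not the literal implementation; when the condition is already a Tensor, the scalar_tensor step is skipped):

    cond_tensor = torch.scalar_tensor(cond)   # lift the scalar condition to a tensor
    torch._assert_async(cond_tensor, "msg")   # appears as aten._assert_async.msg in the aten graph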

torch/_inductor/decomposition.py

Lines changed: 7 additions & 0 deletions
@@ -54,6 +54,13 @@ def _unsafe_view(self, size):
     return self.view(size)
 
 
+# TODO: for now, inductor doesn't handle asserts
+# because the condition is symbool -> tensor in the graph.
+@register_decomposition([aten._assert_async.msg])
+def assert_async_msg_decomp(tensor, msg):
+    return
+
+
 @register_decomposition([aten.clamp])
 @pw_cast_for_opmath
 def clamp(x, min=None, max=None):

torch/_meta_registrations.py

Lines changed: 10 additions & 0 deletions
@@ -295,6 +295,16 @@ def meta_angle_out(self, out):
     return out.copy_(torch.angle(self))
 
 
+@register_meta(aten._assert_async.default)
+def assert_async(val):
+    return
+
+
+@register_meta(aten._assert_async.msg)
+def assert_async_meta(val, assert_msg):
+    return
+
+
 # From aten/src/ATen/native/LinearAlgebraUtils.h
 def squareCheckInputs(self: Tensor, f_name: str):
     assert (

torch/fx/node.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,8 @@
 
 _side_effectful_functions: Set[Callable] = {
     torch._assert,
+    torch._assert_async,
+    _ops.aten._assert_async.msg,
     _ops.aten.copy_.default,
     _ops.profiler._record_function_enter,
     _ops.profiler._record_function_enter_new,
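Registering the new overload as side-effectful is what keeps assert nodes alive through dead-code elimination even though their output is never used. A rough illustration of that behavior, assuming this change (the function f and its inputs are hypothetical):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        # traced to an aten._assert_async.msg node via the new meta/CPU kernels
        torch._assert_async(torch.scalar_tensor(x.shape[0] > 0), "batch must be non-empty")
        return x * 2

    gm = make_fx(f)(torch.randn(3))
    gm.graph.eliminate_dead_code()  # the assert node survives: it is listed in _side_effectful_functions
    gm.recompile()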

torch/overrides.py

Lines changed: 1 addition & 1 deletion
@@ -387,7 +387,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.argmin: lambda input: -1,
     torch.argsort: lambda input, dim=None: -1,
     torch.asin: lambda input, out=None: -1,
-    torch._assert_async: lambda input: -1,
+    torch._assert_async: lambda input, msg: -1,
     torch.arcsin: lambda input, out=None: -1,
     torch.asinh: lambda input, out=None: -1,
     torch.arcsinh: lambda input, out=None: -1,

torchgen/native_function_generation.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
 # All of these operators don't have any tensor like returns
 FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
     "_assert_async",  # no return
+    "_assert_async.msg",  # no return
     "_dimI",  # returns an int
     "_dimV",  # returns an int
     "_has_same_storage_numel",  # returns a boolean
