diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json
index a039b63f17e..edfb6d06442 100644
--- a/.jenkins/metadata.json
+++ b/.jenkins/metadata.json
@@ -29,7 +29,7 @@
         "needs": "linux.16xlarge.nvidia.gpu"
     },
     "intermediate_source/torchvision_tutorial.py": {
-        "needs": "linux.g5.4xlarge.nvidia.gpu",
+        "needs": "linux.g5.4xlarge.nvidia.gpu", "_comment": "does not require a5g but needs to run before gpu_quantization_torchao_tutorial.py."
     },
     "advanced_source/coding_ddpg.py": {
@@ -39,6 +39,9 @@
     "intermediate_source/torch_compile_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     },
+    "intermediate_source/torch_export_tutorial.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu"
+    },
     "intermediate_source/scaled_dot_product_attention_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     },
diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py
index 5ca573cb398..cf578af0108 100644
--- a/intermediate_source/torch_export_tutorial.py
+++ b/intermediate_source/torch_export_tutorial.py
@@ -11,11 +11,12 @@
 # .. warning::
 #
 #    ``torch.export`` and its related features are in prototype status and are subject to backwards compatibility
-#    breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.2.
+#    breaking changes. This tutorial provides a snapshot of ``torch.export`` usage as of PyTorch 2.3.
 #
 # :func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
 # standardized model representations, intended
-# to be run on different (i.e. Python-less) environments.
+# to be run on different (i.e. Python-less) environments. The official
+# documentation can be found `here <https://pytorch.org/docs/main/export.html>`__.
 #
 # In this tutorial, you will learn how to use :func:`torch.export` to extract
 # ``ExportedProgram``'s (i.e. single-graph representations) from PyTorch programs.
@@ -71,7 +72,7 @@ def forward(self, x, y):
 mod = MyModule()
 exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
 print(type(exported_mod))
-print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))
+print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
 
 
 ######################################################################
@@ -100,7 +101,7 @@ def forward(self, x, y):
 # Other attributes of interest in ``ExportedProgram`` include:
 #
 # - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
-# - ``range_constraints`` and ``equality_constraints`` -- constraints, covered later
+# - ``range_constraints`` -- constraints, covered later
 
 print(exported_mod.graph_signature)
 
@@ -123,54 +124,58 @@ def forward(self, x, y):
 #
 # - data-dependent control flow
 
-def bad1(x):
-    if x.sum() > 0:
-        return torch.sin(x)
-    return torch.cos(x)
+class Bad1(torch.nn.Module):
+    def forward(self, x):
+        if x.sum() > 0:
+            return torch.sin(x)
+        return torch.cos(x)
 
 import traceback as tb
 try:
-    export(bad1, (torch.randn(3, 3),))
+    export(Bad1(), (torch.randn(3, 3),))
 except Exception:
     tb.print_exc()
 
 ######################################################################
 # - accessing tensor data with ``.data``
 
-def bad2(x):
-    x.data[0, 0] = 3
-    return x
+class Bad2(torch.nn.Module):
+    def forward(self, x):
+        x.data[0, 0] = 3
+        return x
 
 try:
-    export(bad2, (torch.randn(3, 3),))
+    export(Bad2(), (torch.randn(3, 3),))
 except Exception:
     tb.print_exc()
 
 ######################################################################
 # - calling unsupported functions (such as many built-in functions)
 
-def bad3(x):
-    x = x + 1
-    return x + id(x)
+class Bad3(torch.nn.Module):
+    def forward(self, x):
+        x = x + 1
+        return x + id(x)
 
 try:
-    export(bad3, (torch.randn(3, 3),))
+    export(Bad3(), (torch.randn(3, 3),))
 except Exception:
     tb.print_exc()
 
 ######################################################################
 # - unsupported Python language features (e.g. throwing exceptions, match statements)
 
-def bad4(x):
-    try:
-        x = x + 1
-        raise RuntimeError("bad")
-    except:
-        x = x + 2
-    return x
+class Bad4(torch.nn.Module):
+    def forward(self, x):
+        try:
+            x = x + 1
+            raise RuntimeError("bad")
+        except:
+            x = x + 2
+        return x
 
 try:
-    export(bad4, (torch.randn(3, 3),))
+    export(Bad4(), (torch.randn(3, 3),))
 except Exception:
     tb.print_exc()
 
@@ -188,16 +193,17 @@ def bad4(x):
 
 from functorch.experimental.control_flow import cond
 
-def bad1_fixed(x):
-    def true_fn(x):
-        return torch.sin(x)
-    def false_fn(x):
-        return torch.cos(x)
-    return cond(x.sum() > 0, true_fn, false_fn, [x])
+class Bad1Fixed(torch.nn.Module):
+    def forward(self, x):
+        def true_fn(x):
+            return torch.sin(x)
+        def false_fn(x):
+            return torch.cos(x)
+        return cond(x.sum() > 0, true_fn, false_fn, [x])
 
-exported_bad1_fixed = export(bad1_fixed, (torch.randn(3, 3),))
-print(exported_bad1_fixed(torch.ones(3, 3)))
-print(exported_bad1_fixed(-torch.ones(3, 3)))
+exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
+print(exported_bad1_fixed.module()(torch.ones(3, 3)))
+print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
 
 ######################################################################
 # There are limitations to ``cond`` that one should be aware of:
@@ -255,7 +261,7 @@ def forward(self, x, y):
 exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))
 
 try:
-    exported_mod2(torch.randn(10, 100), torch.randn(10, 100))
+    exported_mod2.module()(torch.randn(10, 100), torch.randn(10, 100))
 except Exception:
     tb.print_exc()
 
@@ -286,9 +292,10 @@ def forward(self, x, y):
 
 inp1 = torch.randn(10, 10, 2)
 
-def dynamic_shapes_example1(x):
-    x = x[:, 2:]
-    return torch.relu(x)
+class DynamicShapesExample1(torch.nn.Module):
+    def forward(self, x):
+        x = x[:, 2:]
+        return torch.relu(x)
 
 inp1_dim0 = Dim("inp1_dim0")
 inp1_dim1 = Dim("inp1_dim1", min=4, max=18)
@@ -296,22 +303,22 @@ def dynamic_shapes_example1(x):
     "x": {0: inp1_dim0, 1: inp1_dim1},
 }
 
-exported_dynamic_shapes_example1 = export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1)
+exported_dynamic_shapes_example1 = export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1)
 
-print(exported_dynamic_shapes_example1(torch.randn(5, 5, 2)))
+print(exported_dynamic_shapes_example1.module()(torch.randn(5, 5, 2)))
 
 try:
-    exported_dynamic_shapes_example1(torch.randn(8, 1, 2))
+    exported_dynamic_shapes_example1.module()(torch.randn(8, 1, 2))
 except Exception:
     tb.print_exc()
 
 try:
-    exported_dynamic_shapes_example1(torch.randn(8, 20, 2))
+    exported_dynamic_shapes_example1.module()(torch.randn(8, 20, 2))
 except Exception:
     tb.print_exc()
 
 try:
-    exported_dynamic_shapes_example1(torch.randn(8, 8, 3))
+    exported_dynamic_shapes_example1.module()(torch.randn(8, 8, 3))
 except Exception:
     tb.print_exc()
 
@@ -325,7 +332,7 @@ def dynamic_shapes_example1(x):
 }
 
 try:
-    export(dynamic_shapes_example1, (inp1,), dynamic_shapes=dynamic_shapes1_bad)
+    export(DynamicShapesExample1(), (inp1,), dynamic_shapes=dynamic_shapes1_bad)
 except Exception:
     tb.print_exc()
 
@@ -336,8 +343,9 @@ def dynamic_shapes_example1(x):
 inp2 = torch.randn(4, 8)
 inp3 = torch.randn(8, 2)
 
-def dynamic_shapes_example2(x, y):
-    return x @ y
+class DynamicShapesExample2(torch.nn.Module):
+    def forward(self, x, y):
+        return x @ y
 
 inp2_dim0 = Dim("inp2_dim0")
 inner_dim = Dim("inner_dim")
@@ -348,12 +356,12 @@ def dynamic_shapes_example2(x, y):
     "y": {0: inner_dim, 1: inp3_dim1},
 }
 
-exported_dynamic_shapes_example2 = export(dynamic_shapes_example2, (inp2, inp3), dynamic_shapes=dynamic_shapes2)
+exported_dynamic_shapes_example2 = export(DynamicShapesExample2(), (inp2, inp3), dynamic_shapes=dynamic_shapes2)
 
-print(exported_dynamic_shapes_example2(torch.randn(2, 16), torch.randn(16, 4)))
+print(exported_dynamic_shapes_example2.module()(torch.randn(2, 16), torch.randn(16, 4)))
 
 try:
-    exported_dynamic_shapes_example2(torch.randn(4, 8), torch.randn(4, 2))
+    exported_dynamic_shapes_example2.module()(torch.randn(4, 8), torch.randn(4, 2))
 except Exception:
     tb.print_exc()
 
@@ -367,10 +375,11 @@ def dynamic_shapes_example2(x, y):
 inp4 = torch.randn(8, 16)
 inp5 = torch.randn(16, 32)
 
-def dynamic_shapes_example3(x, y):
-    if x.shape[0] <= 16:
-        return x @ y[:, :16]
-    return y
+class DynamicShapesExample3(torch.nn.Module):
+    def forward(self, x, y):
+        if x.shape[0] <= 16:
+            return x @ y[:, :16]
+        return y
 
 dynamic_shapes3 = {
     "x": {i: Dim(f"inp4_dim{i}") for i in range(inp4.dim())},
@@ -378,7 +387,7 @@ def dynamic_shapes_example3(x, y):
 }
 
 try:
-    export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3)
+    export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3)
 except Exception:
     tb.print_exc()
 
@@ -400,8 +409,8 @@ def suggested_fixes():
     }
 
 dynamic_shapes3_fixed = suggested_fixes()
-exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
-print(exported_dynamic_shapes_example3(torch.randn(4, 32), torch.randn(32, 64)))
+exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
+print(exported_dynamic_shapes_example3.module()(torch.randn(4, 32), torch.randn(32, 64)))
 
 ######################################################################
 # Note that in the example above, because we constrained the value of ``x.shape[0]`` in
@@ -414,18 +423,16 @@ def suggested_fixes():
 import logging
 torch._logging.set_logs(dynamic=logging.INFO, dynamo=logging.INFO)
-exported_dynamic_shapes_example3 = export(dynamic_shapes_example3, (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
+exported_dynamic_shapes_example3 = export(DynamicShapesExample3(), (inp4, inp5), dynamic_shapes=dynamic_shapes3_fixed)
 
 # reset to previous values
 torch._logging.set_logs(dynamic=logging.WARNING, dynamo=logging.WARNING)
 
 ######################################################################
-# We can view an ``ExportedProgram``'s constraints using the ``range_constraints`` and
-# ``equality_constraints`` attributes. The logging above reveals what the symbols ``s0, s1, ...``
-# represent.
+# We can view an ``ExportedProgram``'s symbolic shape ranges using the
+# ``range_constraints`` field.
 
 print(exported_dynamic_shapes_example3.range_constraints)
-print(exported_dynamic_shapes_example3.equality_constraints)
 
 ######################################################################
 # Custom Ops
 # ----------
@@ -438,7 +445,7 @@ def suggested_fixes():
 # - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/stable/library.html>`__)
 #   as with any other custom op
 
-from torch.library import Library, impl
+from torch.library import Library, impl, impl_abstract
 
 m = Library("my_custom_library", "DEF")
 
@@ -453,25 +460,26 @@ def custom_op(x):
 # - Define a ``"Meta"`` implementation of the custom op that returns an empty
 #   tensor with the same shape as the expected output
 
-@impl(m, "custom_op", "Meta")
+@impl_abstract("my_custom_library::custom_op")
 def custom_op_meta(x):
     return torch.empty_like(x)
 
 ######################################################################
 # - Call the custom op from the code you want to export using ``torch.ops``
 
-def custom_op_example(x):
-    x = torch.sin(x)
-    x = torch.ops.my_custom_library.custom_op(x)
-    x = torch.cos(x)
-    return x
+class CustomOpExample(torch.nn.Module):
+    def forward(self, x):
+        x = torch.sin(x)
+        x = torch.ops.my_custom_library.custom_op(x)
+        x = torch.cos(x)
+        return x
 
 ######################################################################
 # - Export the code as before
 
-exported_custom_op_example = export(custom_op_example, (torch.randn(3, 3),))
+exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
 exported_custom_op_example.graph_module.print_readable()
-print(exported_custom_op_example(torch.randn(3, 3)))
+print(exported_custom_op_example.module()(torch.randn(3, 3)))
 
 ######################################################################
 # Note in the above outputs that the custom op is included in the exported graph.
@@ -606,6 +614,51 @@ def cond_predicate(x):
 # ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach
 # out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``.
 
+######################################################################
+# Running the Exported Program
+# ----------------------------
+#
+# As ``torch.export`` is only a graph capturing mechanism, calling the artifact
+# produced by ``torch.export`` eagerly will be equivalent to running the eager
+# module. To optimize the execution of the Exported Program, we can pass this
+# exported artifact to backends such as Inductor through ``torch.compile``,
+# `AOTInductor <https://pytorch.org/docs/main/torch.compiler_aot_inductor.html>`__,
+# or `TensorRT <https://github.com/pytorch/TensorRT>`__.
+
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 3)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+inp = torch.randn(2, 3, device="cuda")
+m = M().to(device="cuda")
+ep = torch.export.export(m, (inp,))
+
+# Run it eagerly
+res = ep.module()(inp)
+print(res)
+
+# Run it with torch.compile
+res = torch.compile(ep.module(), backend="inductor")(inp)
+print(res)
+
+import torch._export
+import torch._inductor
+
+# Note: these APIs are subject to change
+# Compile the exported program to a .so using AOTInductor
+with torch.no_grad():
+    so_path = torch._inductor.aot_compile(ep.module(), [inp])
+# Load and run the .so file in Python.
+# To load and run it in a C++ environment, see:
+# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
+res = torch._export.aot_load(so_path, device="cuda")(inp)
+print(res)
+
 ######################################################################
 # Conclusion
 # ----------
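+
+######################################################################
+# A note on the TensorRT path mentioned above, which is not exercised in
+# this tutorial: as a minimal sketch, assuming the optional
+# ``torch_tensorrt`` package is installed and that its dynamo frontend
+# (``torch_tensorrt.dynamo.compile``) accepts an ``ExportedProgram``
+# directly, lowering the same exported program would look roughly like
+# the following (shown as a non-executed literal block):
+#
+# ::
+#
+#    import torch_tensorrt
+#
+#    # Hypothetical usage sketch -- requires a CUDA device and the
+#    # torch_tensorrt package; the exact API may differ across versions.
+#    trt_mod = torch_tensorrt.dynamo.compile(ep, inputs=[inp])
+#    print(trt_mod(inp))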