diff --git a/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst b/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
new file mode 100644
index 0000000000..d64f81d4aa
--- /dev/null
+++ b/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
@@ -0,0 +1,109 @@
+.. _writing_dynamo_aten_lowering_passes:
+
+Writing Dynamo ATen Lowering Passes
+===================================
+
+Basics of a Lowering Pass
+-------------------------
+
+ATen lowering passes are Python functions which take as input a graph of ATen operators, apply some desired modification such as operator coalescing/fusion, operator replacement, subgraph rewriting, custom operator insertion, or other operations on a `torch.fx.GraphModule`, and then return the modified graph to the caller. These lowering passes generally modify the graph in-place and return the same input object.
+
+Lowering Pass Requirements
+--------------------------
+
+An ATen lowering pass function in Torch-TRT must satisfy two requirements:
+
+- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule`
+- The function must leave the graph in a valid and invocable state, including performing any necessary linting and recompilation
+
+See this link for information on `Graph Manipulations `_ in FX. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT Engines.
+
+Example Lowering Pass
+---------------------
+
+.. code-block:: python
+
+    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """Repair scenarios where inputs are also outputs of the graph
+
+        TRT does not allow such cases, so we insert a clone (identity) layer
+        """
+        modified_graph = False
+
+        # Extract graph placeholder Tensors
+        placeholders = [
+            node
+            for node in gm.graph.nodes
+            if (
+                node.op == "placeholder"
+                and isinstance(node.type, type)
+                and issubclass(node.type, torch.Tensor)
+            )
+        ]
+
+        for placeholder in placeholders:
+            # If any placeholder has any users which are direct graph outputs
+            if len(placeholder.users) >= 1 and any(
+                user.op == "output" for user in placeholder.users
+            ):
+                modified_graph = True
+
+                # Get direct graph outputs which are direct uses of placeholders
+                direct_outputs = [user for user in placeholder.users if user.op == "output"]
+
+                # Insert clone node for placeholder to ensure
+                # placeholder is not a direct output
+                with gm.graph.inserting_after(placeholder):
+                    cloned_placeholder = gm.graph.call_function(
+                        torch.ops.aten.clone.default,
+                        args=(placeholder,),
+                    )
+
+                # Replace placeholder as output with cloned version
+                for output in direct_outputs:
+                    output.replace_input_with(placeholder, cloned_placeholder)
+
+        # If the graph was modified, clean up the graph and ensure it is up-to-date
+        if modified_graph:
+            gm.graph.eliminate_dead_code()
+            gm.graph.lint()
+            gm.recompile()
+            logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")
+
+        return gm
+
+
+Registering Lowering Passes
+---------------------------
+
+Lowering passes are currently registered in `py/torch_tensorrt/dynamo/lowering/passes/__init__.py`, using the `torch.fx.passes.pass_manager.PassManager` utility to assemble the list of passes in a desired order. New passes added directly to that list will be applied to graphs in the Torch-TensorRT `torch.compile` backend.
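+For illustration, a pass could be appended to that list manually, as sketched below; `add_pass_with_index` is a method on the `DynamoPassManager` registry, and `remove_detach` is a hypothetical example pass:
+
+.. code-block:: python
+
+    import torch
+
+    from torch_tensorrt.dynamo.lowering.passes import ATEN_LOWERING_PASSES
+
+    def remove_detach(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        """Hypothetical pass which removes aten.detach nodes (no-ops for inference)"""
+        for node in list(gm.graph.nodes):
+            if node.target == torch.ops.aten.detach.default:
+                # Reroute users of the detach node to its input, then erase it
+                node.replace_all_uses_with(node.args[0])
+                gm.graph.erase_node(node)
+
+        # Leave the graph in a valid, recompiled state
+        gm.graph.eliminate_dead_code()
+        gm.graph.lint()
+        gm.recompile()
+        return gm
+
+    # Append the pass at the end of the registered pass list
+    ATEN_LOWERING_PASSES.add_pass_with_index(remove_detach)
+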
+Currently, we offer an ATen lowering pass registration decorator for convenience, which can be invoked either directly or with the optional `index` keyword argument, which controls where in the pass list the lowering pass will be inserted.
+
+For instance, to insert the pass at the default location (end of the list), the following code can be used:
+
+.. code-block:: python
+
+    @_aten_lowering_pass
+    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        ...
+
+Alternatively, to insert the pass at a custom index (such as the front of the list), the following code can be used:
+
+.. code-block:: python
+
+    @_aten_lowering_pass(index=0)
+    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        ...
+
+Utilities are also provided in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.
+
+.. code-block:: python
+
+    # Print all lowering passes in the list
+    print(dump_lowering_passes())
+
+    # Apply lowering passes to a GraphModule
+    apply_lowering_passes(graph_module)
+
+    # Remove the lowering pass at index 1
+    _remove_lowering_pass(index=1)
+
+**Note:** The above APIs are subject to change as the lowering pass system evolves.
diff --git a/docsrc/index.rst b/docsrc/index.rst
index eee62bc2f7..ded3b99c9d 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -128,6 +128,7 @@ Contributor Documentation
 --------------------------------
 * :ref:`system_overview`
 * :ref:`writing_converters`
+* :ref:`writing_dynamo_aten_lowering_passes`
 * :ref:`useful_links`
 
 .. toctree::
@@ -137,6 +138,7 @@ Contributor Documentation
 
    contributors/system_overview
    contributors/writing_converters
+   contributors/writing_dynamo_aten_lowering_passes
    contributors/useful_links
 
 Indices
diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py
index 32225e79fc..b271d0d6fb 100644
--- a/py/torch_tensorrt/dynamo/aten_tracer.py
+++ b/py/torch_tensorrt/dynamo/aten_tracer.py
@@ -6,8 +6,7 @@
 
 import torch
 from torch._export import export
-from torch_tensorrt.dynamo.backend.backends import constant_fold
-from torch_tensorrt.dynamo.lowering import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
 from torch_tensorrt.dynamo.utils import set_log_level
 
 logger = logging.getLogger(__name__)
@@ -29,6 +28,6 @@ def trace(
         "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
     ):
         graph_module = export(model, tuple(inputs)).module()
-        constant_fold(graph_module)
+        graph_module = apply_lowering_passes(graph_module)
         logger.debug("Post export graph: " + str(graph_module.graph))
         return graph_module
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 7fde8bbb41..022f3b193d 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -10,24 +10,12 @@
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
 from torch._ops import OpOverload
-from torch_tensorrt._utils import sanitized_torch_version
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
-from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level -from packaging import version - -# Modify import location of utilities based on Torch version -if version.parse(sanitized_torch_version()) < version.parse("2.1.1"): - from torch._inductor.freezing import ConstantFolder, replace_node_with_constant -else: - from torch._inductor.constant_folding import ( - ConstantFolder, - replace_node_with_constant, - ) - logger = logging.getLogger(__name__) @@ -84,7 +72,7 @@ def _pretraced_backend( fake_mode, "allow_non_fake_inputs", True ), fake_mode: # Invoke AOTAutograd to translate operators to aten - graph_module = aot_export_for_compile( + gm = aot_export_for_compile( gm, sample_inputs, decompositions=get_decompositions( @@ -94,10 +82,10 @@ def _pretraced_backend( logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) - constant_fold(graph_module) + gm = apply_lowering_passes(gm) trt_compiled = compile_module( - graph_module, + gm, sample_inputs, settings=settings, ) @@ -121,35 +109,6 @@ def _pretraced_backend( raise -@torch.utils._python_dispatch._disable_current_modes() # type: ignore -def constant_fold(gm: torch.fx.GraphModule) -> Any: - """Adapted from: - https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197 - - Folds constants in the graph module, not skipping constructors - - Modifies the graph in-place and replaces node with constants - """ - cf = ConstantFolder(gm, skip_constructors=False) - cf.run() - - for node, constant in cf.node_replacements.items(): - replace_node_with_constant(gm, node, constant) - - erased_params = [] - for node in gm.graph.nodes: - if node.op == "get_attr" and len(node.users) == 0: - delattr(gm, node.target) - erased_params.append(node) - - for node in erased_params: - gm.graph.erase_node(node) - - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - - def aot_export_for_compile( func: torch.fx.GraphModule, args: Sequence[torch.Tensor], diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py index 6eda61a6fd..34faa1d11b 100644 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -2,4 +2,5 @@ from ._fusers import * # noqa: F401 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401 from ._pre_aot_lowering import register_substitution # noqa: F401 +from .passes import apply_lowering_passes from .substitutions import * # noqa: F401 diff --git a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py new file mode 100644 index 0000000000..ea393fab14 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py @@ -0,0 +1 @@ +from ._aten_lowering_pass import * diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py new file mode 100644 index 0000000000..a4c7fad607 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -0,0 +1,76 @@ +import logging +from typing import Callable, Optional + +import torch + +from .constant_folding import constant_fold +from .pass_manager import DynamoPassManager +from .repair_input_as_output import repair_input_as_output + +ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist( + [ + constant_fold, + repair_input_as_output, + ] +) + +logger = 
logging.getLogger(__name__) + + +LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule] + + +def _aten_lowering_pass( + *args: LoweringPassSignature, + index: Optional[int] = None, +) -> LoweringPassSignature: + """Adds a lowering pass to the registry, at a specified index if desired + + If no index is specified, the lowering pass is inserted at the end of the list + """ + + def add_lowering_pass( + lowering_pass: LoweringPassSignature, + ) -> LoweringPassSignature: + ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index) + logger.debug( + f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}" + ) + return lowering_pass + + # If there are arguments specified, the decorator may have been called as-is + if args: + # The decorator may only be called with the lowering pass + # The index must be specified as a keyword argument + if len(args) == 1 and callable(args[0]): + return add_lowering_pass(args[0]) + else: + raise AssertionError( + f"aten_lowering_pass decorator called with invalid arguments {args} " + "To specify an index to insert the pass, use the keyword 'index='" + ) + # If no arguments are specified, the decorator was called with an index keyword + else: + return add_lowering_pass + + +def _remove_lowering_pass(*, index: int) -> None: + """Removes a lowering pass at a specific index from the registry""" + ATEN_LOWERING_PASSES.remove_pass_with_index(index) + logger.debug( + f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}" + ) + return + + +def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Applies the lowering passes to a graph module, returns the modified GraphModule""" + logging.debug( + f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}" + ) + return ATEN_LOWERING_PASSES(gm) + + +def dump_lowering_passes() -> str: + """Returns a string containing the lowering passes""" + return str(ATEN_LOWERING_PASSES) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py new file mode 100644 index 0000000000..d17d0a2528 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -0,0 +1,56 @@ +import logging + +import torch +from torch_tensorrt._utils import sanitized_torch_version + +from packaging import version + +# Modify import location of utilities based on Torch version +if version.parse(sanitized_torch_version()) < version.parse("2.1.1"): + from torch._inductor.freezing import ConstantFolder, replace_node_with_constant +else: + from torch._inductor.constant_folding import ( + ConstantFolder, + replace_node_with_constant, + ) + +logger = logging.getLogger(__name__) + + +@torch.utils._python_dispatch._disable_current_modes() # type: ignore +def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Adapted from: + https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197 + + Folds constants in the graph module, not skipping constructors + + Modifies the graph in-place and replaces node with constants + """ + cf = ConstantFolder(gm, skip_constructors=False) + cf.run() + + for node, constant in cf.node_replacements.items(): + replace_node_with_constant(gm, node, constant) + + erased_params = [] + for node in gm.graph.nodes: + # If get_attr node has no users, mark it for deletion + if node.op == "get_attr" and len(node.users) == 0: + # 
If the node's parameter is not a parameter of any other node, remove it + if not any( + other.target == node.target for other in gm.graph.nodes if other != node + ): + delattr(gm, node.target) + erased_params.append(node) + + # Remove unused nodes from the graph + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + logger.debug(f"Graph after constant folding:\n{gm.graph}") + + return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py new file mode 100644 index 0000000000..51e2584364 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py @@ -0,0 +1,42 @@ +from typing import Any, Callable, List, Optional + +import torch +from torch.fx.passes.pass_manager import PassManager + + +class DynamoPassManager(PassManager): # type: ignore[misc] + def __init__( + self, + passes: Optional[ + List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]] + ] = None, + ): + super().__init__(passes) + + @classmethod + def build_from_passlist( + cls, + passes: Optional[List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]], + ) -> Any: + pm = DynamoPassManager(passes) + return pm + + def add_pass_with_index( + self, + lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule], + index: Optional[int] = None, + ) -> None: + if index is None: + self.passes.append(lowering_pass) + index = -1 + else: + self.passes.insert(index, lowering_pass) + + def remove_pass_with_index(self, index: int) -> None: + del self.passes[index] + + def __call__(self, source: Any) -> Any: + return super().__call__(source) + + def __str__(self) -> str: + return str(self.passes) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py new file mode 100644 index 0000000000..6ce846637d --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py @@ -0,0 +1,53 @@ +import logging + +import torch + +logger = logging.getLogger(__name__) + + +def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Repair scenarios where inputs are also outputs of the graph + + TRT does not allow such cases, so we insert a clone (identity) layer + """ + modified_graph = False + + # Extract graph placeholder Tensors + placeholders = [ + node + for node in gm.graph.nodes + if ( + node.op == "placeholder" + and isinstance(node.type, type) + and issubclass(node.type, torch.Tensor) + ) + ] + + for placeholder in placeholders: + # If any placeholder has any users which are direct graph outputs + if len(placeholder.users) >= 1 and any( + user.op == "output" for user in placeholder.users + ): + modified_graph = True + + # Get direct graph outputs which are direct uses of placeholders + direct_outputs = [user for user in placeholder.users if user.op == "output"] + + # Insert clone node for placeholder to ensure placeholder is not a direct output + with gm.graph.inserting_after(placeholder): + cloned_placeholder = gm.graph.call_function( + torch.ops.aten.clone.default, + args=(placeholder,), + ) + + # Replace placeholder as output with cloned version + for output in direct_outputs: + output.replace_input_with(placeholder, cloned_placeholder) + + if modified_graph: + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}") + + return gm diff --git a/setup.py 
b/setup.py index 6b013daf9e..d02adfb678 100644 --- a/setup.py +++ b/setup.py @@ -392,6 +392,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.unary", "torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.substitutions", + "torch_tensorrt.dynamo.lowering.passes", "torch_tensorrt.dynamo.partitioning", "torch_tensorrt.dynamo.runtime", "torch_tensorrt.dynamo.tools", @@ -419,6 +420,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary", "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering", "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions", + "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes", "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning", "torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime", "torch_tensorrt.dynamo.tools": "py/torch_tensorrt/dynamo/tools", diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py new file mode 100644 index 0000000000..a63c5e3439 --- /dev/null +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -0,0 +1,95 @@ +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests + +from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing + + +class TestInputAsOutput(TestCase): + def test_input_as_output(self): + class InputAsOutput(torch.nn.Module): + def forward(self, x, y): + y_new = y + x + 1 + y_new = y_new * 7 + return (y_new, x, y) + + inputs = [ + torch.rand( + 5, + 7, + ).cuda(), + torch.rand( + 5, + 7, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(InputAsOutput()) + lower_graph_testing(fx_graph, inputs, min_block_size=1) + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = torch.cat( + [tensor.detach().cpu() for tensor in optimized_model(*inputs)] + ) + torch_model_results = torch.cat( + [tensor.detach().cpu() for tensor in fx_graph(*inputs)] + ) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"InputAsOutput TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + +class TestLoweringPassMembership(TestCase): + def insert_at_end(self): + from torch_tensorrt.dynamo.lowering.passes import ( + ATEN_LOWERING_PASSES, + _aten_lowering_pass, + _remove_lowering_pass, + ) + + @_aten_lowering_pass + def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + return gm + + self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[-1]) + + _remove_lowering_pass(-1) + + self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes) + + def insert_at_index(self): + from torch_tensorrt.dynamo.lowering.passes import ( + ATEN_LOWERING_PASSES, + _aten_lowering_pass, + _remove_lowering_pass, + ) + + @_aten_lowering_pass(index=0) + def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + return gm + + self.assertEqual(identity_pass, ATEN_LOWERING_PASSES.passes[0]) + + _remove_lowering_pass(0) + + self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes) + + +if __name__ == "__main__": + run_tests() diff --git 
a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index f311f2db2b..af5336813f 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -6,8 +6,8 @@ import torch from torch._dynamo.utils import detect_fake_mode from torch_tensorrt.dynamo import partitioning -from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile, constant_fold -from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions +from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile +from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions DECIMALS_OF_AGREEMENT = 4 @@ -40,16 +40,16 @@ def fx_dynamo_testing_backend( fake_mode, "allow_non_fake_inputs", True ), fake_mode: # Invoke AOTAutograd to translate operators to aten - graph_module = aot_export_for_compile( + gm = aot_export_for_compile( gm, sample_inputs, decompositions=get_decompositions(), ) - constant_fold(graph_module) + gm = apply_lowering_passes(gm) trt_compiled = custom_backend( - graph_module, + gm, sample_inputs, ) return trt_compiled
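As a quick usage illustration of the registration and application APIs introduced above (a sketch, not part of the diff; it assumes the module paths added in this change are importable):

.. code-block:: python

    import torch

    from torch_tensorrt.dynamo.lowering import apply_lowering_passes
    from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
        _aten_lowering_pass,
        _remove_lowering_pass,
        dump_lowering_passes,
    )


    @_aten_lowering_pass(index=0)
    def log_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Identity pass: print the graph and return it unchanged
        print(gm.graph)
        return gm


    # The new pass should now appear at the front of the registered pass list
    print(dump_lowering_passes())

    # Run all registered lowering passes over a small traced module
    gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Flatten()))
    gm = apply_lowering_passes(gm)

    # Remove the illustration pass from the registry when done
    _remove_lowering_pass(index=0)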