feat: Add ATen lowering pass system #2280

Merged
109 changes: 109 additions & 0 deletions
docsrc/contributors/writing_dynamo_aten_lowering_passes.rst
.. _writing_dynamo_aten_lowering_passes:

Writing Dynamo ATen Lowering Passes
===================================

Basics of a Lowering Pass
-------------------------

ATen lowering passes are Python functions which take as input a graph of ATen operators, apply some desired modification such as operator coalescing/fusion, operator replacement, subgraph rewriting, or custom operator insertion to a `torch.fx.GraphModule`, then return the modified graph to the caller. These lowering passes generally modify the graph in-place and return the same object that was passed in.

Lowering Pass Requirements
--------------------------

An ATen lowering pass function in Torch-TensorRT must satisfy two requirements:
- The function must take a single `torch.fx.GraphModule` as input and return the lowered `torch.fx.GraphModule`
- The function must leave the graph in a valid and invokable state, including performing any necessary linting and recompilation
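
As a minimal sketch of this contract (the pass below is hypothetical and shown for illustration only), a pass satisfying both requirements looks like:

.. code-block:: python

    import torch

    def my_noop_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # ... modify gm.graph in-place here ...

        # Leave the graph in a valid, invokable state before returning it
        gm.graph.lint()
        gm.recompile()
        return gm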

See this link for information on `Graph Manipulations <https://pytorch.org/docs/stable/fx.html#graph-manipulation>`_ in FX. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT Engines.

Example Lowering Pass
---------------------

.. code-block:: python

    def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Repair scenarios where inputs are also outputs of the graph

        TRT does not allow such cases, so we insert a clone (identity) layer
        """
        modified_graph = False

        # Extract graph placeholder Tensors
        placeholders = [
            node
            for node in gm.graph.nodes
            if (
                node.op == "placeholder"
                and isinstance(node.type, type)
                and issubclass(node.type, torch.Tensor)
            )
        ]

        for placeholder in placeholders:
            # If any placeholder has any users which are direct graph outputs
            if len(placeholder.users) >= 1 and any(
                user.op == "output" for user in placeholder.users
            ):
                modified_graph = True

                # Get direct graph outputs which are direct uses of placeholders
                direct_outputs = [user for user in placeholder.users if user.op == "output"]

                # Insert clone node for placeholder to ensure
                # placeholder is not a direct output
                with gm.graph.inserting_after(placeholder):
                    cloned_placeholder = gm.graph.call_function(
                        torch.ops.aten.clone.default,
                        args=(placeholder,),
                    )

                # Replace placeholder as output with cloned version
                for output in direct_outputs:
                    output.replace_input_with(placeholder, cloned_placeholder)

        # If the graph was modified, clean up the graph and ensure it is up-to-date
        if modified_graph:
            gm.graph.eliminate_dead_code()
            gm.graph.lint()
            gm.recompile()
            logger.debug(f"Graph after repair_input_as_output:\n{gm.graph}")

        return gm

Registering Lowering Passes
---------------------------

Lowering passes are currently registered in `py/torch_tensorrt/dynamo/lowering/passes/__init__.py`, using the `torch.fx.passes.pass_manager.PassManager` utility to assemble the list of passes in a desired order. New passes added directly to that list will be applied to graphs in the Torch-TensorRT `torch.compile` backend. Currently, we offer an ATen lowering pass registration decorator for convenience, which can be invoked either directly or with the optional `index` keyword argument, which controls where in the pass list the lowering pass will be inserted.

For instance, to insert the pass at the default location (end of the list), the following code can be used:

.. code-block:: python

    @_aten_lowering_pass
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

Alternatively, to insert the pass at a custom index (such as the front of the list), the following code can be used:

.. code-block:: python

    @_aten_lowering_pass(index=0)
    def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        ...

There are also utilities provided in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index.

.. code-block:: python

    # Print all lowering passes in the list
    print(dump_lowering_passes())

    # Apply lowering passes to a GraphModule
    apply_lowering_passes(graph_module)

    # Remove the lowering pass at index 1
    _remove_lowering_pass(index=1)

**Note:** The above APIs are subject to change as the lowering pass system evolves.
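
For completeness, here is a sketch of obtaining a `torch.fx.GraphModule` to experiment with. It uses the standard `torch.fx.symbolic_trace` (which produces torch-level rather than ATen-level operators; the Torch-TensorRT backend ordinarily produces ATen graphs via Dynamo), so it is for illustration only:

.. code-block:: python

    import torch

    class MyModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(x) + 1

    graph_module = torch.fx.symbolic_trace(MyModule())
    graph_module = apply_lowering_passes(graph_module)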
1 change: 1 addition & 0 deletions
py/torch_tensorrt/dynamo/lowering/passes/__init__.py

from ._aten_lowering_pass import *
76 changes: 76 additions & 0 deletions
py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
import logging
from typing import Callable, Optional

import torch

from .constant_folding import constant_fold
from .pass_manager import DynamoPassManager
from .repair_input_as_output import repair_input_as_output

ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
    [
        constant_fold,
        repair_input_as_output,
    ]
)

logger = logging.getLogger(__name__)


LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule]


def _aten_lowering_pass(
    *args: LoweringPassSignature,
    index: Optional[int] = None,
) -> LoweringPassSignature:
    """Adds a lowering pass to the registry, at a specified index if desired

    If no index is specified, the lowering pass is inserted at the end of the list
    """

    def add_lowering_pass(
        lowering_pass: LoweringPassSignature,
    ) -> LoweringPassSignature:
        ATEN_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
        logger.debug(
            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
        )
        return lowering_pass

    # If arguments were specified, the decorator may have been called as-is
    if args:
        # The decorator may only be called with the lowering pass itself;
        # the index must be specified as a keyword argument
        if len(args) == 1 and callable(args[0]):
            return add_lowering_pass(args[0])
        else:
            raise AssertionError(
                f"aten_lowering_pass decorator called with invalid arguments {args}. "
                "To specify an index at which to insert the pass, use the keyword 'index='"
            )
    # If no arguments were specified, the decorator was called with an index keyword
    else:
        return add_lowering_pass


def _remove_lowering_pass(*, index: int) -> None:
    """Removes the lowering pass at a specific index from the registry"""
    ATEN_LOWERING_PASSES.remove_pass_with_index(index)
    logger.debug(
        f"Removed lowering pass at index {index}, current passlist: {ATEN_LOWERING_PASSES}"
    )
    return


def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Applies the lowering passes to a graph module, returns the modified GraphModule"""
    logger.debug(
        f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}"
    )
    return ATEN_LOWERING_PASSES(gm)


def dump_lowering_passes() -> str:
    """Returns a string containing the lowering passes"""
    return str(ATEN_LOWERING_PASSES)
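
As a usage sketch (the pass below, `remove_detach`, is hypothetical and shown only to exercise the registry utilities above):

import torch

@_aten_lowering_pass(index=0)
def remove_detach(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Replace each aten.detach call with its input, then clean up the graph
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.detach.default:
            node.replace_all_uses_with(node.args[0])
    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm

print(dump_lowering_passes())   # remove_detach now appears at index 0
_remove_lowering_pass(index=0)  # and can be removed again by index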
56 changes: 56 additions & 0 deletions
py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
import logging

import torch
from torch_tensorrt._utils import sanitized_torch_version

from packaging import version

# Modify import location of utilities based on Torch version
if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
    from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
else:
    from torch._inductor.constant_folding import (
        ConstantFolder,
        replace_node_with_constant,
    )

logger = logging.getLogger(__name__)


@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Adapted from:
    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197

    Folds constants in the graph module, not skipping constructors

    Modifies the graph in-place and replaces nodes with constants
    """
    cf = ConstantFolder(gm, skip_constructors=False)
    cf.run()

    for node, constant in cf.node_replacements.items():
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        # If a get_attr node has no users, mark it for deletion
        if node.op == "get_attr" and len(node.users) == 0:
            # If the node's attribute is not referenced by any other node, remove it
            if not any(
                other.target == node.target for other in gm.graph.nodes if other != node
            ):
                delattr(gm, node.target)
                erased_params.append(node)

    # Remove unused nodes from the graph
    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()

    logger.debug(f"Graph after constant folding:\n{gm.graph}")

    return gm
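
A brief usage sketch (illustrative only; `ConstantFolder` is an internal PyTorch API, so exact folding behavior may vary across versions):

import torch

class Scale(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `self.weight * 2` does not depend on x, so it is a candidate for folding
        return x + self.weight * 2

gm = torch.fx.symbolic_trace(Scale())
gm = constant_fold(gm)
print(gm.graph)  # the weight * 2 subexpression should now appear as a single constant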