
Commit a43552b

feat: Add maxpool indices lowering pass
- Add lowering pass to switch `indices` variants to non-`indices` variants
- Add testing for lowering passes
- Remove unused directory
1 parent 8ebf24d commit a43552b

File tree

13 files changed: +249 -20 lines

.pre-commit-config.yaml

+1 -1

@@ -40,7 +40,7 @@ repos:
     rev: 'v1.4.1'
     hooks:
       - id: mypy
-        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools|^docs|noxfile.py|setup.py|versions.py"
+        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
     rev: v0.0.278

py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py renamed to py/torch_tensorrt/dynamo/_experimental/einsum.py

+2 -1

@@ -3,11 +3,12 @@
 import torch
 import torch._custom_ops as library
 from torch.fx.node import Argument, Target
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
+from ._pre_aot_lowering import register_substitution
+
 library.custom_op(
     "tensorrt::einsum",
     "(str equation, Tensor[] tensors) -> Tensor",

py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py renamed to py/torch_tensorrt/dynamo/_experimental/maxpool1d.py

+2 -1

@@ -3,11 +3,12 @@
 import torch
 import torch._custom_ops as library
 from torch.fx.node import Argument, Target
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
+from ._pre_aot_lowering import register_substitution
+
 # This file serves as an example and a tutorial for excluding custom modules from
 # torch.compile tracing. Each required step is labeled with a number indicating the
 # preferable implementation order.

tests/py/dynamo/backend/test_pre_aot_lowering.py renamed to py/torch_tensorrt/dynamo/_experimental/test_pre_aot_lowering.py

+3 -3

@@ -2,7 +2,7 @@
 import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests
 
-from ..testing_utilities import lower_graph_testing
+from .....tests.py.dynamo.testing_utilities import lower_graph_testing
 
 
 class TestMaxPool1D(TestCase):

@@ -52,7 +52,7 @@ def forward(self, x):
 
         max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
         self.assertAlmostEqual(
-            max_diff, 0, f"Maxpool1d TRT outputs don't match with the original model."
+            max_diff, 0, "Maxpool1d TRT outputs don't match with the original model."
         )
 
 

@@ -102,7 +102,7 @@ def forward(self, x, y):
 
         max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
         self.assertAlmostEqual(
-            max_diff, 0, f"Einsum TRT outputs don't match with the original model."
+            max_diff, 0, "Einsum TRT outputs don't match with the original model."
         )
 
 

py/torch_tensorrt/dynamo/backend/backends.py

-4

@@ -15,7 +15,6 @@
     get_decompositions,
     repair_input_aliasing,
 )
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs, set_log_level
 
 logger = logging.getLogger(__name__)

@@ -64,9 +63,6 @@ def _pretraced_backend(
     try:
         logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))
 
-        # Perform Pre-AOT Lowering for Module-Level Replacement
-        gm = pre_aot_substitutions(gm)
-
         fake_mode = detect_fake_mode(sample_inputs)
 
         # Place backend tracing within FakeTensor context allowing nonfake Tensors
py/torch_tensorrt/dynamo/lowering/__init__.py

-3

@@ -1,7 +1,4 @@
 from ._decompositions import get_decompositions  # noqa: F401
 from ._fusers import *  # noqa: F401
-from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
-from ._pre_aot_lowering import register_substitution  # noqa: F401
 from ._repair_input_aliasing import repair_input_aliasing
 from .passes import apply_lowering_passes
-from .substitutions import *  # noqa: F401

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

+2

@@ -9,6 +9,7 @@
 from .pass_manager import DynamoPassManager
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
+from .replace_max_pool_with_indices import replace_max_pool_with_indices
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [

@@ -17,6 +18,7 @@
         repair_input_as_output,
         lower_efficient_attention,
         fuse_prims_broadcast,
+        replace_max_pool_with_indices,
     ]
 )

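For reference, each entry in this pass list shares one contract: a callable taking the FX `GraphModule` and the sample inputs, and returning the (possibly rewritten) module. A minimal sketch of that contract follows; the pass name and body are hypothetical, for illustration only:

```python
from typing import Sequence

import torch


def inspect_graph_pass(  # hypothetical pass, illustrating the signature only
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    # A real pass would rewrite gm.graph here and clean up afterwards;
    # this sketch merely logs each node and returns the module unchanged
    for node in gm.graph.nodes:
        print(node.op, node.target)
    return gm
```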
py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py

+60 (new file)

@@ -0,0 +1,60 @@
+import logging
+import operator
+from typing import Sequence
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def replace_max_pool_with_indices(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace MaxPool nodes which return unused indices"""
+    replacement_dict = {
+        torch.ops.aten.max_pool1d_with_indices.default: torch.ops.aten.max_pool1d.default,
+        torch.ops.aten.max_pool2d_with_indices.default: torch.ops.aten.max_pool2d.default,
+        torch.ops.aten.max_pool3d_with_indices.default: torch.ops.aten.max_pool3d.default,
+    }
+
+    modified_graph = False
+
+    for node in gm.graph.nodes:
+        # If the node is a max-pool variant returning indices, and its only
+        # user is a getitem node selecting output 0 (the pooled values), the
+        # indices output is unused and the op can be swapped for the plain variant
+        if (
+            node.target in replacement_dict
+            and len(node.users) == 1
+            and list(node.users)[0].target == operator.getitem
+            and list(node.users)[0].args[1] == 0
+        ):
+            modified_graph = True
+
+            # Replace all uses of the getitem node with the fused maxpool node
+            getitem_node = list(node.users)[0]
+
+            with gm.graph.inserting_after(getitem_node):
+                maxpool_fused = gm.graph.call_function(
+                    replacement_dict[node.target],
+                    args=node.args,
+                    kwargs=node.kwargs,
+                )
+
+            logger.debug(
+                f"Replacing nodes {node} and {getitem_node} with fused maxpool node "
+                f"{maxpool_fused}, as the indices output of the maxpool was unused."
+            )
+
+            getitem_node.replace_all_uses_with(maxpool_fused)
+            gm.graph.erase_node(getitem_node)
+            gm.graph.erase_node(node)
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after fusing maxpool operators with indices:\n{gm.graph}")
+
+    return gm
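
To see the pass in action, here is a minimal sketch (assuming torch and torch_tensorrt are installed; the module, kernel size, and input shape are illustrative): tracing produces a `max_pool2d_with_indices` node whose indices output is discarded via `getitem`, and the pass collapses the pair into a single `max_pool2d` call.

```python
import torch
import torch.fx
from torch_tensorrt.dynamo.lowering.passes.replace_max_pool_with_indices import (
    replace_max_pool_with_indices,
)


class Pool(torch.nn.Module):
    def forward(self, x):
        # Returns (values, indices); indexing [0] discards the indices,
        # which traces to a getitem node with argument 0
        return torch.ops.aten.max_pool2d_with_indices.default(x, [3, 3])[0]


gm = torch.fx.symbolic_trace(Pool())
print(gm.graph)  # shows max_pool2d_with_indices followed by getitem

gm = replace_max_pool_with_indices(gm, [torch.randn(1, 3, 8, 8)])
print(gm.graph)  # both nodes replaced by a single max_pool2d call
```

The same rewrite applies to the 1D and 3D variants via the replacement dictionary above.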

py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py

-2
This file was deleted.

setup.py

-2

@@ -391,7 +391,6 @@ def run(self):
     "torch_tensorrt.dynamo.conversion.impl.slice",
     "torch_tensorrt.dynamo.conversion.impl.unary",
     "torch_tensorrt.dynamo.lowering",
-    "torch_tensorrt.dynamo.lowering.substitutions",
     "torch_tensorrt.dynamo.lowering.passes",
     "torch_tensorrt.dynamo.partitioning",
     "torch_tensorrt.dynamo.runtime",

@@ -419,7 +418,6 @@ def run(self):
     "torch_tensorrt.dynamo.conversion.impl.slice": "py/torch_tensorrt/dynamo/conversion/impl/slice",
     "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary",
     "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering",
-    "torch_tensorrt.dynamo.lowering.substitutions": "py/torch_tensorrt/dynamo/lowering/substitutions",
     "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes",
     "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning",
     "torch_tensorrt.dynamo.runtime": "py/torch_tensorrt/dynamo/runtime",

tests/py/dynamo/lowering/test_decompositions.py

+179

@@ -313,6 +313,185 @@ def forward(self, x):
             f"Var TRT outputs don't match with the original model.",
         )
 
+    def test_lowering_maxpool1d_functional(self):
+        class MaxPool1d(torch.nn.Module):
+            def forward(self, x):
+                y = torch.nn.functional.max_pool1d(x, 3)
+                return y
+
+        # Operations expected (and not expected) in the traced graph after lowering
+        expected_ops = {torch.ops.aten.max_pool2d.default}
+        unexpected_ops = {
+            torch.ops.aten.max_pool1d_with_indices.default,
+            torch.ops.aten.max_pool2d_with_indices.default,
+        }
+
+        inputs = [torch.randn(4, 8, 27).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(MaxPool1d())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "MaxPool1d TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_maxpool_2d_module(self):
+        class MaxPool2d(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.maxpool = torch.nn.MaxPool2d((5, 3), stride=(2, 1))
+
+            def forward(self, x):
+                y = self.maxpool(x)
+                return y
+
+        # Operations expected (and not expected) in the traced graph after lowering
+        expected_ops = {torch.ops.aten.max_pool2d.default}
+        unexpected_ops = {torch.ops.aten.max_pool2d_with_indices.default}
+
+        inputs = [torch.randn(1, 3, 25, 30).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(MaxPool2d())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "MaxPool2d TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_maxpool_3d_module(self):
+        class MaxPool3d(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.maxpool = torch.nn.MaxPool3d(3)
+
+            def forward(self, x):
+                y = self.maxpool(x)
+                return y
+
+        # Operations expected (and not expected) in the traced graph after lowering
+        expected_ops = {torch.ops.aten.max_pool3d.default}
+        unexpected_ops = {torch.ops.aten.max_pool3d_with_indices.default}
+
+        inputs = [torch.randn(4, 8, 27, 72, 96).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(MaxPool3d())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "MaxPool3d TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/testing_utilities.py

-3

@@ -12,7 +12,6 @@
     get_decompositions,
     repair_input_aliasing,
 )
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 
 DECIMALS_OF_AGREEMENT = 4
 

@@ -35,8 +34,6 @@ def fx_dynamo_testing_backend(
         use_fast_partitioner=use_fast_partitioner,
    )
 
-    gm = pre_aot_substitutions(gm)
-
     fake_mode = detect_fake_mode(sample_inputs)
 
     # Place backend tracing within FakeTensor context allowing nonfake Tensors
