
Commit a1bfda6

fix: Reorganize Dynamo directory + backends (#1928)
1 parent 8dcb030 commit a1bfda6

20 files changed, +103 -60 lines

.circleci/config.yml  (+11 -11)

@@ -740,33 +740,33 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs

-  test-dynamo-torch_compile-core:
-    description: "Test the Dynamo torch_compile path"
+  test-dynamo-compile-core:
+    description: "Test the Dynamo compile path"
     steps:
       - run:
-          name: Run Dynamo torch_compile core tests
+          name: Run Dynamo compile core tests
           command: |
-            cd py/torch_tensorrt/dynamo/torch_compile
+            cd py/torch_tensorrt/dynamo/backend
             pushd test/
-            pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml
             popd

       - store_test_results:
           path: /tmp/artifacts
       - store_artifacts:
           path: /tmp/testlogs

-  test-dynamo-torch_compile:
-    description: "Test the Dynamo torch_compile path"
+  test-dynamo-compile:
+    description: "Test the Dynamo compile path"
     steps:
       - run:
-          name: Run Dynamo torch_compile E2E tests
+          name: Run Dynamo compile E2E tests
           command: |
             cd py/torch_tensorrt/dynamo/
             pushd test/
             pip3 install timm
             pip3 install transformers
-            pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo_compile
             popd

       - store_test_results:
@@ -1000,8 +1000,8 @@ jobs:
             command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
             # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
       - dump-test-env
-      - test-dynamo-torch_compile
-      - test-dynamo-torch_compile-core
+      - test-dynamo-compile
+      - test-dynamo-compile-core
       - test-dynamo-fx_ts

  package-x86_64-linux:

py/torch_tensorrt/__init__.py  (+1 -1)

@@ -97,7 +97,7 @@ def _find_lib(name, paths):

 if version.parse(torch.__version__) >= version.parse("2.dev"):
     from torch_tensorrt import dynamo
-    from torch_tensorrt.dynamo import torch_compile
+    from torch_tensorrt.dynamo import backend


 def _register_with_torch():

py/torch_tensorrt/_compile.py  (+6 -6)

@@ -16,7 +16,7 @@ class _IRType(Enum):
     ts = 0
     fx = 1
     fx_ts_compat = 2
-    torch_compile = 3
+    dynamo_compile = 3


 class _ModuleType(Enum):
@@ -47,7 +47,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:

     ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
     ir_targets_fx = ir == "fx"
-    ir_targets_torch_compile = ir == "torch_compile"
+    ir_targets_dynamo_compile = ir == "dynamo_compile"
     ir_targets_fx_ts_compat = ir == "fx_ts_compat"

     if module_is_tsable and ir_targets_torchscript:
@@ -56,8 +56,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
         return _IRType.fx
     elif module_is_fxable and ir_targets_fx_ts_compat:
         return _IRType.fx_ts_compat
-    elif module_is_fxable and ir_targets_torch_compile:
-        return _IRType.torch_compile
+    elif module_is_fxable and ir_targets_dynamo_compile:
+        return _IRType.dynamo_compile
     else:
         if ir == "default":
             # Options are listed in order of preference
@@ -156,8 +156,8 @@ def compile(
             dynamic_batch=False,
             **kwargs,
         )
-    elif target_ir == _IRType.torch_compile:
-        return torch_tensorrt.dynamo.torch_compile(
+    elif target_ir == _IRType.dynamo_compile:
+        return torch_tensorrt.dynamo.compile(
             module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
         )
     elif target_ir == _IRType.fx_ts_compat:
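For reference, a minimal sketch of how the renamed IR option is selected from the public API after this change; the toy module and input shape below are placeholders for illustration, not part of the commit:

import torch
import torch_tensorrt

# Placeholder module; any FX-traceable nn.Module works here.
model = torch.nn.Sequential(torch.nn.Linear(16, 4), torch.nn.ReLU()).eval().cuda()

# ir="dynamo_compile" now maps to _IRType.dynamo_compile and dispatches to
# torch_tensorrt.dynamo.compile (previously ir="torch_compile").
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo_compile",
    inputs=[torch_tensorrt.Input((1, 16))],
    enabled_precisions={torch.float},
)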

py/torch_tensorrt/dynamo/__init__.py  (+1 -1)

@@ -1,2 +1,2 @@
 from torch_tensorrt.dynamo import fx_ts_compat
-from .torch_compile import compile as torch_compile
+from .backend import compile

py/torch_tensorrt/dynamo/torch_compile/__init__.py renamed to py/torch_tensorrt/dynamo/backend/__init__.py  (+5 -5)

@@ -8,10 +8,10 @@
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.fx.utils import LowerPrecision

-from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
-from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, prepare_device
-from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend
-from torch_tensorrt.dynamo.torch_compile._defaults import (
+from torch_tensorrt.dynamo.backend._settings import CompilationSettings
+from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
+from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
+from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
     MAX_WORKSPACE_SIZE,
@@ -121,6 +121,6 @@ def create_backend(
     )

     return partial(
-        tensorrt_backend,
+        torch_tensorrt_backend,
         settings=settings,
     )
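A hedged sketch of using create_backend with torch.compile after the rename; the keyword arguments of create_backend are not shown in this hunk, so the call below assumes the defaults are acceptable, and the toy module is illustrative only:

import torch
from torch_tensorrt.dynamo.backend import create_backend

# Illustrative module only.
model = torch.nn.Linear(8, 8).eval().cuda()

# create_backend returns partial(torch_tensorrt_backend, settings=...), which
# torch.compile accepts as a callable backend.
trt_backend = create_backend()  # assumption: default CompilationSettings are fine
compiled = torch.compile(model, backend=trt_backend)
out = compiled(torch.randn(4, 8, device="cuda"))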

py/torch_tensorrt/dynamo/torch_compile/_defaults.py renamed to py/torch_tensorrt/dynamo/backend/_defaults.py  (+1 -1)

@@ -4,4 +4,4 @@
 PRECISION = LowerPrecision.FP32
 DEBUG = False
 MAX_WORKSPACE_SIZE = 20 << 30
-MAX_NUM_TRT_ENGINES = 200
+MAX_NUM_TRT_ENGINES = 10

py/torch_tensorrt/dynamo/torch_compile/_settings.py renamed to py/torch_tensorrt/dynamo/backend/_settings.py  (+1 -1)

@@ -1,7 +1,7 @@
 from dataclasses import dataclass

 from torch_tensorrt.fx.utils import LowerPrecision
-from torch_tensorrt.dynamo.torch_compile._defaults import (
+from torch_tensorrt.dynamo.backend._defaults import (
     PRECISION,
     DEBUG,
     MAX_WORKSPACE_SIZE,

py/torch_tensorrt/dynamo/torch_compile/backends.py renamed to py/torch_tensorrt/dynamo/backend/backends.py  (+29 -19)

@@ -4,30 +4,42 @@
 from functools import partial
 import torch._dynamo as td

-from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
-from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
+from torch_tensorrt.dynamo.backend._settings import CompilationSettings
+from torch_tensorrt.dynamo.backend.lowering._decompositions import (
     get_decompositions,
 )
-from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
+from torch_tensorrt.dynamo.backend.lowering._partition import (
     partition,
     get_submod_inputs,
 )
-from torch_tensorrt.dynamo.torch_compile.conversion import convert_module
+from torch_tensorrt.dynamo.backend.conversion import convert_module

 from torch._dynamo.backends.common import fake_tensor_unsupported

 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler


-@td.register_backend(name="tensorrt")
+@td.register_backend(name="torch_tensorrt")
 @fake_tensor_unsupported
-def tensorrt_backend(
-    gm: torch.nn.Module,
+def torch_tensorrt_backend(
+    gm: torch.fx.GraphModule,
+    sample_inputs: Sequence[torch.Tensor],
+    settings: CompilationSettings = CompilationSettings(),
+):
+    DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
+
+    return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)
+
+
+@td.register_backend(name="aot_torch_tensorrt_aten")
+@fake_tensor_unsupported
+def aot_torch_tensorrt_aten_backend(
+    gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
 ):
     custom_backend = partial(
-        fx_dynamo_backend,
+        _pretraced_backend,
         settings=settings,
     )

@@ -40,14 +52,12 @@ def tensorrt_backend(
     )


-@td.register_backend(name="fx_tensorrt")
-@fake_tensor_unsupported
-def fx_dynamo_backend(
+def _pretraced_backend(
     gm: torch.fx.GraphModule,
-    example_inputs: Sequence[torch.Tensor],
+    sample_inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
 ):
-    """Helper function to manage translation of FX module to TRT engines
+    """Helper function to manage translation of traced FX module to TRT engines

     Args:
         module: FX GraphModule to convert
@@ -57,9 +67,9 @@ def fx_dynamo_backend(
         Compiled FX GraphModule
     """
     try:
-        trt_compiled = compile_module(
+        trt_compiled = _compile_module(
             gm,
-            example_inputs,
+            sample_inputs,
             settings=settings,
         )
         return trt_compiled
@@ -72,12 +82,12 @@ def fx_dynamo_backend(
         return gm.forward


-def compile_module(
+def _compile_module(
     gm: torch.fx.GraphModule,
-    example_inputs: Sequence[torch.Tensor],
+    sample_inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
 ) -> torch.fx.GraphModule:
-    """Compile an FX module
+    """Compile a traced FX module

     Includes: Partitioning + Conversion Phases

@@ -100,7 +110,7 @@ def compile_module(

         # Get submodule inputs
         submodule_inputs = get_submod_inputs(
-            partitioned_module, submodule, example_inputs
+            partitioned_module, submodule, sample_inputs
         )

         # Create TRT Module from submodule
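Because the backends are now registered with TorchDynamo under new names, a minimal sketch of invoking them by their registered string names; the module and input below are placeholders, and the old "tensorrt"/"fx_tensorrt" names no longer exist after this commit:

import torch
import torch_tensorrt  # importing registers the backends via @td.register_backend

# Placeholder module for illustration.
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.GELU()).eval().cuda()

# "torch_tensorrt" forwards to aot_torch_tensorrt_aten_backend, which can also
# be requested directly by name.
compiled = torch.compile(model, backend="torch_tensorrt")
# compiled_aten = torch.compile(model, backend="aot_torch_tensorrt_aten")

out = compiled(torch.randn(8, 32, device="cuda"))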
(new file)  (+7 -0)

@@ -0,0 +1,7 @@
+from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+    get_decompositions,
+)
+from torch_tensorrt.dynamo.backend.lowering._partition import (
+    partition,
+    get_submod_inputs,
+)

py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py renamed to py/torch_tensorrt/dynamo/backend/lowering/_partition.py  (+1 -1)

@@ -2,7 +2,7 @@

 import torch

-from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES
+from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.operator_support import OperatorSupport


py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py renamed to py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py  (+1 -1)

@@ -1,4 +1,4 @@
-from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs
+from torch_tensorrt.dynamo.backend.utils import prepare_device, prepare_inputs
 from utils import same_output_format
 import torch_tensorrt
 import unittest

py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py renamed to py/torch_tensorrt/dynamo/backend/test/test_partitioning.py  (+1 -1)

@@ -1,4 +1,4 @@
-from torch_tensorrt.dynamo.torch_compile.lowering import partition
+from torch_tensorrt.dynamo.backend.lowering import partition
 from torch.testing._internal.common_utils import run_tests, TestCase
 import torch
 from copy import deepcopy

py/torch_tensorrt/dynamo/torch_compile/test/utils.py renamed to py/torch_tensorrt/dynamo/backend/test/utils.py  (+2 -2)

@@ -2,10 +2,10 @@
 from functools import partial
 from typing import List, Sequence
 import torch
-from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
+from torch_tensorrt.dynamo.backend.lowering._decompositions import (
     get_decompositions,
 )
-from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
+from torch_tensorrt.dynamo.backend.lowering._partition import (
     partition,
 )

py/torch_tensorrt/dynamo/torch_compile/utils.py renamed to py/torch_tensorrt/dynamo/backend/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def prepare_inputs(
4545

4646
else:
4747
raise ValueError(
48-
f"Invalid input type {type(inputs)} encountered in the torch_compile input parsing. "
48+
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
4949
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
5050
)
5151

py/torch_tensorrt/dynamo/test/conftest.py  (+1 -1)

@@ -9,7 +9,7 @@ def pytest_addoption(parser):
         type=str,
         required=True,
         help="IR to compile with",
-        choices=["torch_compile", "fx_ts_compat"],
+        choices=["dynamo_compile", "fx_ts_compat"],
     )

