Skip to content

Commit 3a034e1

Browse files
committed
chore: ready to start review
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 6f802d1 commit 3a034e1

File tree

101 files changed

+522
-587
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+522
-587
lines changed

.pre-commit-config.yaml

+14-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ repos:
55
hooks:
66
- id: check-yaml
77
- id: trailing-whitespace
8+
exclude: ^docs
89
- id: check-added-large-files
910
args:
1011
- --maxkb=1000
@@ -13,11 +14,7 @@ repos:
1314
- id: mixed-line-ending
1415
args:
1516
- --fix=lf
16-
- repo: https://github.com/psf/black
17-
rev: 23.7.0
18-
hooks:
19-
- id: black
20-
exclude: ^examples/custom_converters/elu_converter/setup.py
17+
exclude: ^docs
2118
- repo: https://github.com/pre-commit/mirrors-clang-format
2219
rev: v16.0.6
2320
hooks:
@@ -30,21 +27,26 @@ repos:
3027
args:
3128
- --warnings=all
3229
- id: buildifier-lint
33-
- repo: https://github.com/astral-sh/ruff-pre-commit
34-
# Ruff version.
35-
rev: v0.0.278
36-
hooks:
37-
- id: ruff
3830
- repo: https://github.com/abravalheri/validate-pyproject
3931
rev: v0.13
4032
hooks:
4133
- id: validate-pyproject
34+
python_version: "3.11"
4235
- repo: https://github.com/pre-commit/mirrors-mypy
4336
rev: 'v1.4.1'
4437
hooks:
4538
- id: mypy
46-
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools"
47-
python_version: "3.11"
39+
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools|^docs|noxfile.py|setup.py|versions.py"
40+
- repo: https://github.com/astral-sh/ruff-pre-commit
41+
# Ruff version.
42+
rev: v0.0.278
43+
hooks:
44+
- id: ruff
45+
- repo: https://github.com/psf/black
46+
rev: 23.7.0
47+
hooks:
48+
- id: black
49+
exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
4850
- repo: local
4951
hooks:
5052
- id: dont-commit-upstream

core/conversion/evaluators/aten.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,12 @@ DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(
103103
"aten::pow.float_int(float a, int b) -> (float)",
104104
}));
105105

106-
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(and, "aten::__and__", a&& b, bool, std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
106+
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
107+
and,
108+
"aten::__and__",
109+
a&& b,
110+
bool,
111+
std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
107112
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"});
108113
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
109114
xor,

docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# %%
1818

19+
1920
# We begin by defining a model
2021
class Model(torch.nn.Module):
2122
def __init__(self) -> None:

docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# %%
1818

19+
1920
# We begin by defining a model
2021
class Model(torch.nn.Module):
2122
def __init__(self) -> None:

py/torch_tensorrt/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ def _find_lib(name: str, paths: List[str]) -> str:
8282

8383
import torch
8484
from torch_tensorrt._compile import * # noqa: F403
85+
from torch_tensorrt._Device import Device # noqa: F401
8586
from torch_tensorrt._enums import * # noqa: F403
86-
from torch_tensorrt._util import * # noqa: F403
87-
from torch_tensorrt._Input import Input # noqa: F401
88-
from torch_tensorrt._Device import Device # noqa: F401
87+
from torch_tensorrt._Input import Input # noqa: F401
88+
from torch_tensorrt._utils import * # noqa: F403
8989

9090
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
9191
from torch_tensorrt import dynamo # noqa: F401

py/torch_tensorrt/_compile.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Any, Callable, List, Optional, Set, TypeGuard
2+
from typing import Any, Callable, List, Optional, Set, TypeGuard, Sequence
33

44
import torch
55
import torch.fx
@@ -15,13 +15,13 @@
1515

1616

1717
def _non_fx_input_interface(
18-
inputs: List[Input | torch.Tensor | InputTensorSpec],
18+
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
1919
) -> TypeGuard[List[Input | torch.Tensor]]:
2020
return all(isinstance(i, torch.Tensor | Input) for i in inputs)
2121

2222

2323
def _fx_input_interface(
24-
inputs: List[Input | torch.Tensor | InputTensorSpec],
24+
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
2525
) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
2626
return all(isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs)
2727

@@ -97,7 +97,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
9797
def compile(
9898
module: Any,
9999
ir: str = "default",
100-
inputs: Optional[List[Input | torch.Tensor | InputTensorSpec]] = None,
100+
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
101101
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
102102
**kwargs: Any,
103103
) -> (
@@ -201,7 +201,7 @@ def compile(
201201
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
202202

203203

204-
def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Callable[..., Any]:
204+
def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
205205
"""
206206
Returns a boxed model which is the output of torch.compile.
207207
This does not compile the model to TRT. Execute this model on
@@ -216,8 +216,8 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Callable[..., Any]:
216216

217217
def convert_method_to_trt_engine(
218218
module: Any,
219-
inputs: List[Input | torch.Tensor],
220-
method_name: str,
219+
method_name: str = "forward",
220+
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
221221
ir: str = "default",
222222
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
223223
**kwargs: Any,

py/torch_tensorrt/_enums.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from torch_tensorrt._C import dtype, EngineCapability, TensorFormat # noqa: F401
1+
from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401
2+
23
from tensorrt import DeviceType # noqa: F401

py/torch_tensorrt/_util.py renamed to py/torch_tensorrt/_utils.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from typing import Any
2+
13
import torch
2-
from torch_tensorrt import _C, __version__
4+
from torch_tensorrt._version import __version__
5+
from torch_tensorrt import _C
36

47

58
def dump_build_info() -> None:
@@ -30,7 +33,7 @@ def set_device(gpu_id: int) -> None:
3033
_C.set_device(gpu_id)
3134

3235

33-
def sanitized_torch_version() -> str:
36+
def sanitized_torch_version() -> Any:
3437
return (
3538
torch.__version__
3639
if ".nv" not in torch.__version__

py/torch_tensorrt/dynamo/_SourceIR.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class SourceIR(Enum):
99
TORCHTRT_LOWERED = auto()
1010
UNKNOWN = auto()
1111

12-
def __str__(self):
12+
def __str__(self) -> str:
1313
if self == SourceIR.NN:
1414
return "nn"
1515
elif self == SourceIR.ACC:

py/torch_tensorrt/dynamo/__init__.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from packaging import version
2-
from torch_tensorrt._util import sanitized_torch_version
2+
from torch_tensorrt._utils import sanitized_torch_version
33

44
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
55
from ._settings import * # noqa: F403
6-
from .conversion import * # noqa: F403
7-
from .aten_tracer import trace # noqa: F403
8-
from .conversion.converter_registry import (
9-
DYNAMO_CONVERTERS, # noqa: F403
10-
dynamo_tensorrt_converter, # noqa: F403
11-
)
12-
from .compile import compile # noqa: F403
13-
from ._SourceIR import SourceIR # noqa: F403
6+
from ._SourceIR import SourceIR # noqa: F403
7+
from .aten_tracer import trace # noqa: F403
8+
from .compile import compile # noqa: F403
9+
from .conversion import * # noqa: F403
10+
from .conversion.converter_registry import DYNAMO_CONVERTERS # noqa: F403
11+
from .conversion.converter_registry import dynamo_tensorrt_converter # noqa: F403

py/torch_tensorrt/dynamo/_settings.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,10 @@
99
OPTIMIZATION_LEVEL,
1010
PASS_THROUGH_BUILD_FAILURES,
1111
PRECISION,
12-
USE_PYTHON_RUNTIME,
13-
<<<<<<< HEAD
1412
TRUNCATE_LONG_AND_DOUBLE,
15-
=======
13+
USE_PYTHON_RUNTIME,
1614
VERSION_COMPATIBLE,
1715
WORKSPACE_SIZE,
18-
>>>>>>> e39abb60d (chore: adding isort to pre-commit)
1916
)
2017

2118

py/torch_tensorrt/dynamo/aten_tracer.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import copy
22
import sys
33
from contextlib import contextmanager
4-
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
4+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
55

66
import torch
77
import torch._dynamo as torchdynamo
8-
from torch import _guards
98
from torch.fx.passes.infra.pass_base import PassResult
109
from torch_tensorrt.dynamo.utils import req_torch_version
1110
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
@@ -23,11 +22,7 @@
2322
)
2423
from typing_extensions import TypeAlias
2524

26-
Value: TypeAlias = Union[
27-
Tuple["Value", ...],
28-
List["Value"],
29-
Dict[str, "Value"],
30-
]
25+
Value: TypeAlias = Tuple["Value", ...] | List["Value"] | Dict[str, "Value"]
3126

3227

3328
class DynamoConfig:
@@ -96,7 +91,7 @@ def dynamo_trace(
9691
aten_graph: bool,
9792
tracing_mode: str = "real",
9893
dynamo_config: Optional[DynamoConfig] = None,
99-
) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
94+
) -> Any: # Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
10095
"""
10196
TODO: Once we fully migrate to torchdynamo frontend, we will remove
10297
this config option altogether. For now, it helps with quick
@@ -150,11 +145,11 @@ def trace(
150145
fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
151146
print(fx_module.graph)
152147
for passes in passes_list:
153-
pr: PassResult = passes(fx_module) # type: ignore[assignment] #The type hints in fx are wrong
148+
pr: PassResult = passes(fx_module)
154149
fx_module = pr.graph_module
155150

156151
fx_module(*inputs)
157152

158153
fx_module = run_const_fold(fx_module)
159154
print(fx_module.graph)
160-
return fx_module # type: ignore[no-any-return]
155+
return fx_module
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .backends import torch_tensorrt_backend # noqa: F401
1+
from .backends import torch_tensorrt_backend # noqa: F401

py/torch_tensorrt/dynamo/backend/backends.py

+2-20
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,12 @@
44

55
import torch
66
import torch._dynamo as td
7-
<<<<<<< HEAD
8-
7+
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
98
from torch_tensorrt.dynamo import CompilationSettings
10-
from torch_tensorrt.dynamo.lowering._decompositions import (
11-
get_decompositions,
12-
)
13-
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
14-
pre_aot_substitutions,
15-
)
16-
from torch_tensorrt.dynamo.lowering._partition import (
17-
partition,
18-
get_submod_inputs,
19-
)
20-
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
219
from torch_tensorrt.dynamo.conversion import (
2210
convert_module,
2311
repair_long_or_double_inputs,
2412
)
25-
26-
=======
27-
>>>>>>> e39abb60d (chore: adding isort to pre-commit)
28-
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
29-
from torch_tensorrt.dynamo import CompilationSettings
30-
from torch_tensorrt.dynamo.conversion import convert_module
3113
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
3214
from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
3315
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
@@ -64,7 +46,7 @@ def aot_torch_tensorrt_aten_backend(
6446
return aot_module_simplified(
6547
gm,
6648
sample_inputs,
67-
fw_compiler=make_boxed_compiler(custom_backend), # type: ignore[no-untyped-call]
49+
fw_compiler=make_boxed_compiler(custom_backend),
6850
decompositions=get_decompositions(),
6951
)
7052

py/torch_tensorrt/dynamo/compile.py

+13-23
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import torch_tensorrt
77
from torch.fx.passes.pass_manager import PassManager
88
from torch.fx.passes.splitter_base import SplitResult
9-
from torch_tensorrt import Device, EngineCapability
9+
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
10+
from torch_tensorrt._Device import Device
11+
from torch_tensorrt._enums import (
12+
EngineCapability,
13+
) # TODO: Should probably be the TRT EngineCapability Enum
1014
from torch_tensorrt.dynamo import CompilationSettings
1115
from torch_tensorrt.dynamo._defaults import (
1216
DEBUG,
@@ -15,32 +19,18 @@
1519
OPTIMIZATION_LEVEL,
1620
PASS_THROUGH_BUILD_FAILURES,
1721
PRECISION,
22+
TRUNCATE_LONG_AND_DOUBLE,
1823
USE_PYTHON_RUNTIME,
1924
VERSION_COMPATIBLE,
2025
WORKSPACE_SIZE,
2126
)
2227
from torch_tensorrt.dynamo.backend.backends import _compile_module
2328
from torch_tensorrt.dynamo.conversion import convert_module
24-
<<<<<<< HEAD
25-
26-
from torch_tensorrt.dynamo._defaults import (
27-
PRECISION,
28-
DEBUG,
29-
WORKSPACE_SIZE,
30-
MIN_BLOCK_SIZE,
31-
PASS_THROUGH_BUILD_FAILURES,
32-
MAX_AUX_STREAMS,
33-
VERSION_COMPATIBLE,
34-
OPTIMIZATION_LEVEL,
35-
USE_PYTHON_RUNTIME,
36-
TRUNCATE_LONG_AND_DOUBLE,
29+
from torch_tensorrt.dynamo.lowering._fusers import (
30+
fuse_permute_linear,
31+
fuse_permute_matmul,
3732
)
38-
39-
=======
40-
from torch_tensorrt.dynamo.lowering import fuse_permute_linear, fuse_permute_matmul
4133
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
42-
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
43-
>>>>>>> e39abb60d (chore: adding isort to pre-commit)
4434

4535
logger = logging.getLogger(__name__)
4636

@@ -89,7 +79,7 @@ def compile(
8979
if not isinstance(inputs, collections.abc.Sequence):
9080
inputs = [inputs]
9181

92-
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
82+
_, torch_inputs = prepare_inputs(inputs, prepare_device(device))
9383

9484
if (
9585
torch.float16 in enabled_precisions
@@ -125,7 +115,7 @@ def compile(
125115
"truncate_long_and_double": truncate_long_and_double,
126116
}
127117

128-
settings = CompilationSettings(**compilation_options) # type: ignore[arg-type]
118+
settings = CompilationSettings(**compilation_options)
129119
if kwargs.get("use_capability_partitioner", None):
130120
model = lower_model(gm, torch_inputs)
131121
return _compile_module(model, torch_inputs, settings)
@@ -163,7 +153,7 @@ def lower_model_using_trt_splitter(
163153
) -> SplitResult:
164154
# Perform basic lowering
165155
model = lower_model(model, inputs)
166-
splitter_setting = TRTSplitterSetting() # type: ignore[no-untyped-call]
156+
splitter_setting = TRTSplitterSetting()
167157
splitter_setting.use_implicit_batch_dim = False
168158
splitter_setting.min_acc_module_size = 1
169159
splitter_setting.use_experimental_rt = False
@@ -177,7 +167,7 @@ def lower_model_using_trt_splitter(
177167
def lower_model(
178168
model: torch.nn.Module, inputs: Any, **kwargs: Any
179169
) -> torch.fx.GraphModule:
180-
graph_optimization_pm = PassManager.build_from_passlist( # type: ignore[no-untyped-call]
170+
graph_optimization_pm = PassManager.build_from_passlist(
181171
[fuse_permute_matmul, fuse_permute_linear]
182172
)
183173
lowered_model: torch.fx.GraphModule = graph_optimization_pm(model)

0 commit comments

Comments
 (0)