Skip to content

Commit 002a5c7

Browse files
committed
chore: ready to start review
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 15d7be1 commit 002a5c7

File tree

102 files changed

+569
-609
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+569
-609
lines changed

.pre-commit-config.yaml

+18-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ repos:
55
hooks:
66
- id: check-yaml
77
- id: trailing-whitespace
8+
exclude: ^docs
89
- id: check-added-large-files
910
args:
1011
- --maxkb=1000
@@ -13,11 +14,7 @@ repos:
1314
- id: mixed-line-ending
1415
args:
1516
- --fix=lf
16-
- repo: https://github.com/psf/black
17-
rev: 23.7.0
18-
hooks:
19-
- id: black
20-
exclude: ^examples/custom_converters/elu_converter/setup.py
17+
exclude: ^docs
2118
- repo: https://github.com/pre-commit/mirrors-clang-format
2219
rev: v16.0.6
2320
hooks:
@@ -30,21 +27,30 @@ repos:
3027
args:
3128
- --warnings=all
3229
- id: buildifier-lint
33-
- repo: https://github.com/astral-sh/ruff-pre-commit
34-
# Ruff version.
35-
rev: v0.0.278
36-
hooks:
37-
- id: ruff
3830
- repo: https://github.com/abravalheri/validate-pyproject
3931
rev: v0.13
4032
hooks:
4133
- id: validate-pyproject
34+
- repo: https://github.com/pycqa/isort
35+
rev: 5.12.0
36+
hooks:
37+
- id: isort
38+
name: isort (python)
4239
- repo: https://github.com/pre-commit/mirrors-mypy
4340
rev: 'v1.4.1'
4441
hooks:
4542
- id: mypy
46-
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools"
47-
python_version: "3.11"
43+
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools|^docs|noxfile.py|setup.py|versions.py"
44+
- repo: https://github.com/astral-sh/ruff-pre-commit
45+
# Ruff version.
46+
rev: v0.0.278
47+
hooks:
48+
- id: ruff
49+
- repo: https://github.com/psf/black
50+
rev: 23.7.0
51+
hooks:
52+
- id: black
53+
exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
4854
- repo: local
4955
hooks:
5056
- id: dont-commit-upstream

core/conversion/evaluators/aten.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,12 @@ DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(
104104
"aten::pow.float_int(float a, int b) -> (float)",
105105
}));
106106

107-
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(and, "aten::__and__", a&& b, bool, std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
107+
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
108+
and,
109+
"aten::__and__",
110+
a&& b,
111+
bool,
112+
std::set<std::string>({"aten::__and__(int a, int b) -> (bool)", "aten::__and__.bool(bool a, bool b) -> (bool)"}));
108113
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(or, "aten::__or__", a || b, bool, {"aten::__or__(int a, int b) -> (bool)"});
109114
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
110115
xor,

docs/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# %%
1818

19+
1920
# We begin by defining a model
2021
class Model(torch.nn.Module):
2122
def __init__(self) -> None:

docs/v1.4.0/_downloads/e1ef5a42560a98a132f56a79d0b66f79/dynamo_compile_advanced_usage.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# %%
1818

19+
1920
# We begin by defining a model
2021
class Model(torch.nn.Module):
2122
def __init__(self) -> None:

py/torch_tensorrt/__init__.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
import sys
55
from typing import Dict, List
66

7-
from packaging import version
87
from torch_tensorrt._version import (
98
__cuda_version__,
109
__cudnn_version__,
1110
__tensorrt_version__,
1211
)
1312

13+
from packaging import version
14+
1415
if sys.version_info < (3,):
1516
raise Exception(
1617
"Python 2 has reached end-of-life and is not supported by Torch-TensorRT"
@@ -82,10 +83,10 @@ def _find_lib(name: str, paths: List[str]) -> str:
8283

8384
import torch
8485
from torch_tensorrt._compile import * # noqa: F403
86+
from torch_tensorrt._Device import Device # noqa: F401
8587
from torch_tensorrt._enums import * # noqa: F403
86-
from torch_tensorrt._util import * # noqa: F403
87-
from torch_tensorrt._Input import Input # noqa: F401
88-
from torch_tensorrt._Device import Device # noqa: F401
88+
from torch_tensorrt._Input import Input # noqa: F401
89+
from torch_tensorrt._utils import * # noqa: F403
8990

9091
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
9192
from torch_tensorrt import dynamo # noqa: F401

py/torch_tensorrt/_compile.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Any, Callable, List, Optional, Set, TypeGuard
2+
from typing import Any, Callable, List, Optional, Sequence, Set, TypeGuard
33

44
import torch
55
import torch.fx
@@ -15,13 +15,13 @@
1515

1616

1717
def _non_fx_input_interface(
18-
inputs: List[Input | torch.Tensor | InputTensorSpec],
18+
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
1919
) -> TypeGuard[List[Input | torch.Tensor]]:
2020
return all(isinstance(i, torch.Tensor | Input) for i in inputs)
2121

2222

2323
def _fx_input_interface(
24-
inputs: List[Input | torch.Tensor | InputTensorSpec],
24+
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
2525
) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
2626
return all(isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs)
2727

@@ -97,7 +97,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
9797
def compile(
9898
module: Any,
9999
ir: str = "default",
100-
inputs: Optional[List[Input | torch.Tensor | InputTensorSpec]] = None,
100+
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
101101
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
102102
**kwargs: Any,
103103
) -> (
@@ -201,7 +201,7 @@ def compile(
201201
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
202202

203203

204-
def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Callable[..., Any]:
204+
def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
205205
"""
206206
Returns a boxed model which is the output of torch.compile.
207207
This does not compile the model to TRT. Execute this model on
@@ -216,8 +216,8 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Callable[..., Any]:
216216

217217
def convert_method_to_trt_engine(
218218
module: Any,
219-
inputs: List[Input | torch.Tensor],
220-
method_name: str,
219+
method_name: str = "forward",
220+
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
221221
ir: str = "default",
222222
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
223223
**kwargs: Any,
@@ -266,7 +266,7 @@ def convert_method_to_trt_engine(
266266
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
267267
)
268268
ts_mod = torch.jit.script(module)
269-
return torch_tensorrt.ts.convert_method_to_trt_engine(
269+
return torch_tensorrt.ts.convert_method_to_trt_engine( # type: ignore[no-any-return]
270270
ts_mod,
271271
inputs=inputs,
272272
method_name=method_name,

py/torch_tensorrt/_enums.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from torch_tensorrt._C import dtype, EngineCapability, TensorFormat # noqa: F401
1+
from torch_tensorrt._C import EngineCapability, TensorFormat, dtype # noqa: F401
2+
23
from tensorrt import DeviceType # noqa: F401

py/torch_tensorrt/_util.py renamed to py/torch_tensorrt/_utils.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from typing import Any
2+
13
import torch
2-
from torch_tensorrt import _C, __version__
4+
from torch_tensorrt import _C
5+
from torch_tensorrt._version import __version__
36

47

58
def dump_build_info() -> None:
@@ -30,7 +33,7 @@ def set_device(gpu_id: int) -> None:
3033
_C.set_device(gpu_id)
3134

3235

33-
def sanitized_torch_version() -> str:
36+
def sanitized_torch_version() -> Any:
3437
return (
3538
torch.__version__
3639
if ".nv" not in torch.__version__

py/torch_tensorrt/dynamo/_SourceIR.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class SourceIR(Enum):
99
TORCHTRT_LOWERED = auto()
1010
UNKNOWN = auto()
1111

12-
def __str__(self):
12+
def __str__(self) -> str:
1313
if self == SourceIR.NN:
1414
return "nn"
1515
elif self == SourceIR.ACC:

py/torch_tensorrt/dynamo/__init__.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1+
from torch_tensorrt._utils import sanitized_torch_version
2+
13
from packaging import version
2-
from torch_tensorrt._util import sanitized_torch_version
34

45
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
56
from ._settings import * # noqa: F403
6-
from .conversion import * # noqa: F403
7-
from .aten_tracer import trace # noqa: F403
8-
from .conversion.converter_registry import (
9-
DYNAMO_CONVERTERS, # noqa: F403
10-
dynamo_tensorrt_converter, # noqa: F403
11-
)
12-
from .compile import compile # noqa: F403
13-
from ._SourceIR import SourceIR # noqa: F403
7+
from ._SourceIR import SourceIR # noqa: F403
8+
from .aten_tracer import trace # noqa: F403
9+
from .compile import compile # noqa: F403
10+
from .conversion import * # noqa: F403
11+
from .conversion.converter_registry import DYNAMO_CONVERTERS # noqa: F403
12+
from .conversion.converter_registry import dynamo_tensorrt_converter # noqa: F403

py/torch_tensorrt/dynamo/_settings.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,10 @@
99
OPTIMIZATION_LEVEL,
1010
PASS_THROUGH_BUILD_FAILURES,
1111
PRECISION,
12-
USE_PYTHON_RUNTIME,
13-
<<<<<<< HEAD
1412
TRUNCATE_LONG_AND_DOUBLE,
15-
=======
13+
USE_PYTHON_RUNTIME,
1614
VERSION_COMPATIBLE,
1715
WORKSPACE_SIZE,
18-
>>>>>>> e39abb60d (chore: adding isort to pre-commit)
1916
)
2017

2118

py/torch_tensorrt/dynamo/aten_tracer.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import copy
22
import sys
33
from contextlib import contextmanager
4-
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
4+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
55

66
import torch
77
import torch._dynamo as torchdynamo
8-
from torch import _guards
98
from torch.fx.passes.infra.pass_base import PassResult
109
from torch_tensorrt.dynamo.utils import req_torch_version
1110
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
@@ -23,11 +22,7 @@
2322
)
2423
from typing_extensions import TypeAlias
2524

26-
Value: TypeAlias = Union[
27-
Tuple["Value", ...],
28-
List["Value"],
29-
Dict[str, "Value"],
30-
]
25+
Value: TypeAlias = Tuple["Value", ...] | List["Value"] | Dict[str, "Value"]
3126

3227

3328
class DynamoConfig:
@@ -96,7 +91,7 @@ def dynamo_trace(
9691
aten_graph: bool,
9792
tracing_mode: str = "real",
9893
dynamo_config: Optional[DynamoConfig] = None,
99-
) -> Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
94+
) -> Any: # Tuple[torch.fx.GraphModule, Set[_guards.Guard]]:
10095
"""
10196
TODO: Once we fully migrate to torchdynamo frontend, we will remove
10297
this config option alltogether. For now, it helps with quick
@@ -150,11 +145,11 @@ def trace(
150145
fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
151146
print(fx_module.graph)
152147
for passes in passes_list:
153-
pr: PassResult = passes(fx_module) # type: ignore[assignment] #The type hints in fx are wrong
148+
pr: PassResult = passes(fx_module)
154149
fx_module = pr.graph_module
155150

156151
fx_module(*inputs)
157152

158153
fx_module = run_const_fold(fx_module)
159154
print(fx_module.graph)
160-
return fx_module # type: ignore[no-any-return]
155+
return fx_module
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .backends import torch_tensorrt_backend # noqa: F401
1+
from .backends import torch_tensorrt_backend # noqa: F401

py/torch_tensorrt/dynamo/backend/backends.py

+2-20
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,12 @@
44

55
import torch
66
import torch._dynamo as td
7-
<<<<<<< HEAD
8-
7+
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
98
from torch_tensorrt.dynamo import CompilationSettings
10-
from torch_tensorrt.dynamo.lowering._decompositions import (
11-
get_decompositions,
12-
)
13-
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
14-
pre_aot_substitutions,
15-
)
16-
from torch_tensorrt.dynamo.lowering._partition import (
17-
partition,
18-
get_submod_inputs,
19-
)
20-
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
219
from torch_tensorrt.dynamo.conversion import (
2210
convert_module,
2311
repair_long_or_double_inputs,
2412
)
25-
26-
=======
27-
>>>>>>> e39abb60d (chore: adding isort to pre-commit)
28-
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
29-
from torch_tensorrt.dynamo import CompilationSettings
30-
from torch_tensorrt.dynamo.conversion import convert_module
3113
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
3214
from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
3315
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
@@ -64,7 +46,7 @@ def aot_torch_tensorrt_aten_backend(
6446
return aot_module_simplified(
6547
gm,
6648
sample_inputs,
67-
fw_compiler=make_boxed_compiler(custom_backend), # type: ignore[no-untyped-call]
49+
fw_compiler=make_boxed_compiler(custom_backend),
6850
decompositions=get_decompositions(),
6951
)
7052

0 commit comments

Comments
 (0)