
Commit 68e6aa8

chore: All code is now flake8 compliant

Signed-off-by: Naren Dasan <[email protected]>

1 parent 146b855

33 files changed (+283 -178 lines)

.pre-commit-config.yaml (+5 -1)

@@ -30,6 +30,11 @@ repos:
         args:
           - --warnings=all
       - id: buildifier-lint
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.0.278
+    hooks:
+      - id: ruff
   - repo: https://github.com/abravalheri/validate-pyproject
     rev: v0.13
     hooks:
@@ -47,4 +52,3 @@ repos:
         exclude: "^.pre-commit-config.yaml"
         language: pygrep
         types: [text]
-
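
Note: ruff is a Rust-based linter that implements the flake8 rule set, so this hook catches the kinds of violations fixed throughout this commit before they land. A minimal, hypothetical file showing the two codes this commit touches most often:

    import os  # F401: 'os' imported but unused

    try:
        value = int("42")
    except:  # E722: do not use bare 'except'
        value = 0

Once the hook is installed, `pre-commit run ruff --all-files` lints the whole tree; on ordinary commits it runs automatically against the staged files.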

py/torch_tensorrt/_Device.py (+2 -2)

@@ -1,4 +1,4 @@
-from typing import TypeVar, Optional, Any, Tuple
+from typing import Optional, Any, Tuple
 import sys
 
 if sys.version_info >= (3, 11):
@@ -15,7 +15,7 @@
 
 try:
     from torch_tensorrt import _C
-except:
+except ImportError:
     warnings.warn(
         "Unable to import torchscript frontend core and torch-tensorrt runtime. Some dependent features may be unavailable."
     )
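
Note: flake8 flags a bare `except:` as E722 because it catches `BaseException`, including `SystemExit` and `KeyboardInterrupt`. Narrowing to `ImportError` preserves the fallback while letting unrelated failures surface. A minimal sketch, with a hypothetical module name:

    try:
        import optional_dependency  # hypothetical optional extension
    except ImportError:
        # Only a missing module triggers the fallback; a bare 'except:'
        # would also swallow KeyboardInterrupt and real bugs raised while
        # the module initializes.
        optional_dependency = None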

py/torch_tensorrt/_Input.py (+7 -9)

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Dict, Any, Tuple, Optional, Union, Sequence
+from typing import List, Dict, Any, Tuple, Optional, Sequence
 
 import torch
 
@@ -97,7 +97,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             self.shape_mode = Input._ShapeMode.STATIC
 
         elif len(args) == 0:
-            if not ("shape" in kwargs) and not (
+            if "shape" not in kwargs and not (
                 all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])
             ):
                 raise ValueError(
@@ -298,8 +298,8 @@ def _parse_tensor_domain(
         domain_lo, domain_hi = domain
 
         # Validate type and provided values for domain
-        valid_type_lo = isinstance(domain_lo, int) or isinstance(domain_lo, float)
-        valid_type_hi = isinstance(domain_hi, int) or isinstance(domain_hi, float)
+        valid_type_lo = isinstance(domain_lo, (int, float))
+        valid_type_hi = isinstance(domain_hi, (int, float))
 
         if not valid_type_lo:
             raise ValueError(
@@ -405,12 +405,10 @@ def example_tensor(
         if optimization_profile_field is not None:
             try:
                 assert any(
-                    [
-                        optimization_profile_field == field_name
-                        for field_name in ["min_shape", "opt_shape", "max_shape"]
-                    ]
+                    optimization_profile_field == field_name
+                    for field_name in ["min_shape", "opt_shape", "max_shape"]
                 )
-            except:
+            except AssertionError:
                 raise ValueError(
                     "Invalid field name, expected one of min_shape, opt_shape, max_shape"
                 )
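
Note: two idioms recur in this file. `isinstance` accepts a tuple of types, replacing chained `or` checks, and `any()`/`all()` take generator expressions directly, which short-circuit and skip the intermediate list (the redundant list form is flagged as C419 by flake8-comprehensions/ruff). A small self-contained illustration:

    # Tuple form replaces: isinstance(x, int) or isinstance(x, float)
    assert isinstance(3.0, (int, float))

    # Generator form stops at the first match and allocates no list
    fields = ["min_shape", "opt_shape", "max_shape"]
    assert any(name == "opt_shape" for name in fields)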

py/torch_tensorrt/__init__.py (+7 -16)

@@ -1,12 +1,9 @@
 import ctypes
-import glob
 import os
 import sys
 import platform
-import warnings
 from packaging import version
 from torch_tensorrt._version import (
-    __version__,
     __cuda_version__,
     __cudnn_version__,
     __tensorrt_version__,
@@ -38,8 +35,8 @@ def _find_lib(name: str, paths: List[str]) -> str:
 
 
 try:
-    import tensorrt
-except:
+    import tensorrt  # noqa: F401
+except ImportError:
     cuda_version = _parse_semver(__cuda_version__)
     cudnn_version = _parse_semver(__cudnn_version__)
     tensorrt_version = _parse_semver(__tensorrt_version__)
@@ -85,20 +82,14 @@ def _find_lib(name: str, paths: List[str]) -> str:
 
 import torch
 
-from torch_tensorrt._compile import *
-from torch_tensorrt._util import *
-from torch_tensorrt import ts
-from torch_tensorrt import ptq
-from torch_tensorrt._enums import *
-from torch_tensorrt import logging
-from torch_tensorrt._Input import Input
-from torch_tensorrt._Device import Device
+from torch_tensorrt._compile import *  # noqa: F403
+from torch_tensorrt._util import *  # noqa: F403
+from torch_tensorrt._enums import *  # noqa: F403
 
-from torch_tensorrt import fx
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from torch_tensorrt import dynamo
-    from torch_tensorrt.dynamo import backend
+    from torch_tensorrt import dynamo  # noqa: F401
+    from torch_tensorrt.dynamo import backend  # noqa: F401
 
 
 def _register_with_torch() -> None:
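
Note: the `# noqa` markers suppress a specific lint code on one line: F401 ("imported but unused") for imports kept for their side effects or re-export, and F403 ("star import used") where populating the package namespace is the point. A generic sketch with hypothetical module names:

    # Kept for its import-time side effects (e.g. registering plugins)
    import plugin_module  # noqa: F401

    # Deliberately re-exports a submodule's public names
    from package._api import *  # noqa: F403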

py/torch_tensorrt/_compile.py (+45 -28)

@@ -1,4 +1,4 @@
-from typing import List, Dict, Any, Set, Union, Callable, TypeGuard
+from typing import List, Any, Set, Callable, TypeGuard, Optional
 
 import torch_tensorrt.ts
 
@@ -9,21 +9,24 @@
 import torch.fx
 from enum import Enum
 
-import torch_tensorrt.fx
 from torch_tensorrt.fx import InputTensorSpec
 from torch_tensorrt.fx.utils import LowerPrecision
 
+from torch_tensorrt.dynamo.compile import compile as dynamo_compile
+from torch_tensorrt.fx.lower import compile as fx_compile
+from torch_tensorrt.ts._compiler import compile as torchscript_compile
+
 
 def _non_fx_input_interface(
     inputs: List[Input | torch.Tensor | InputTensorSpec],
 ) -> TypeGuard[List[Input | torch.Tensor]]:
-    return all([isinstance(i, torch.Tensor | Input) for i in inputs])
+    return all(isinstance(i, torch.Tensor | Input) for i in inputs)
 
 
 def _fx_input_interface(
     inputs: List[Input | torch.Tensor | InputTensorSpec],
 ) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
-    return all([isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs])
+    return all(isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs)
 
 
 class _IRType(Enum):
@@ -58,10 +61,10 @@ def _parse_module_type(module: Any) -> _ModuleType:
 
 
 def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
-    module_is_tsable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.ts]])
-    module_is_fxable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.fx]])
+    module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts])
+    module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx])
 
-    ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
+    ir_targets_torchscript = any(ir == opt for opt in ["torchscript", "ts"])
     ir_targets_fx = ir == "fx"
     ir_targets_dynamo = ir == "dynamo"
     ir_targets_torch_compile = ir == "torch_compile"
@@ -97,8 +100,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
 def compile(
     module: Any,
     ir: str = "default",
-    inputs: List[Input | torch.Tensor | InputTensorSpec] = [],
-    enabled_precisions: Set[torch.dtype | dtype] = set([torch.float]),
+    inputs: Optional[List[Input | torch.Tensor | InputTensorSpec]] = None,
+    enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
     **kwargs: Any,
 ) -> (
     torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any]
@@ -138,6 +141,11 @@ def compile(
     Returns:
         torch.nn.Module: Compiled Module, when run it will execute via TensorRT
     """
+    input_list = inputs if inputs is not None else []
+    enabled_precisions_set = (
+        enabled_precisions if enabled_precisions is not None else {torch.float}
+    )
+
     module_type = _parse_module_type(module)
     target_ir = _get_target_ir(module_type, ir)
     if target_ir == _IRType.ts:
@@ -148,45 +156,50 @@ def compile(
             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
         )
         ts_mod = torch.jit.script(module)
-        assert _non_fx_input_interface(inputs)
-        compiled_ts_module: torch.jit.ScriptModule = torch_tensorrt.ts.compile(
-            ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
+        assert _non_fx_input_interface(input_list)
+        compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
+            ts_mod,
+            inputs=input_list,
+            enabled_precisions=enabled_precisions_set,
+            **kwargs,
         )
         return compiled_ts_module
     elif target_ir == _IRType.fx:
         if (
-            torch.float16 in enabled_precisions
-            or torch_tensorrt.dtype.half in enabled_precisions
+            torch.float16 in enabled_precisions_set
+            or torch_tensorrt.dtype.half in enabled_precisions_set
         ):
             lower_precision = LowerPrecision.FP16
         elif (
-            torch.float32 in enabled_precisions
-            or torch_tensorrt.dtype.float in enabled_precisions
+            torch.float32 in enabled_precisions_set
+            or torch_tensorrt.dtype.float in enabled_precisions_set
        ):
             lower_precision = LowerPrecision.FP32
         else:
-            raise ValueError(f"Precision {enabled_precisions} not supported on FX")
+            raise ValueError(f"Precision {enabled_precisions_set} not supported on FX")
 
-        assert _fx_input_interface(inputs)
-        compiled_fx_module: torch.nn.Module = torch_tensorrt.fx.compile(
+        assert _fx_input_interface(input_list)
+        compiled_fx_module: torch.nn.Module = fx_compile(
             module,
-            inputs,
+            input_list,
             lower_precision=lower_precision,
             explicit_batch_dimension=True,
             dynamic_batch=False,
             **kwargs,
         )
         return compiled_fx_module
     elif target_ir == _IRType.dynamo:
-        compiled_aten_module: torch.fx.GraphModule = torch_tensorrt.dynamo.compile(
+        compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
             module,
-            inputs=inputs,
-            enabled_precisions=enabled_precisions,
+            inputs=input_list,
+            enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
         return compiled_aten_module
     elif target_ir == _IRType.torch_compile:
-        return torch_compile(module, enabled_precisions=enabled_precisions, **kwargs)
+        return torch_compile(
+            module, enabled_precisions=enabled_precisions_set, **kwargs
+        )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
 
@@ -206,10 +219,10 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Callable[..., Any]:
 
 def convert_method_to_trt_engine(
     module: Any,
+    inputs: List[Input | torch.Tensor],
     method_name: str,
     ir: str = "default",
-    inputs: List[Input | torch.Tensor] = [],
-    enabled_precisions: Set[torch.dtype | dtype] = set([torch.float]),
+    enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
     **kwargs: Any,
 ) -> bytes:
     """Convert a TorchScript module method to a serialized TensorRT engine
@@ -242,6 +255,10 @@ def convert_method_to_trt_engine(
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
+    enabled_precisions_set = (
+        enabled_precisions if enabled_precisions is not None else {torch.float}
+    )
+
     module_type = _parse_module_type(module)
     target_ir = _get_target_ir(module_type, ir)
     if target_ir == _IRType.ts:
@@ -254,9 +271,9 @@ def convert_method_to_trt_engine(
         ts_mod = torch.jit.script(module)
         return torch_tensorrt.ts.convert_method_to_trt_engine(
             ts_mod,
-            method_name,
             inputs=inputs,
-            enabled_precisions=enabled_precisions,
+            method_name=method_name,
+            enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
     elif target_ir == _IRType.fx:
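
Note: the signature changes replace mutable defaults (`inputs=[]`, `enabled_precisions=set([torch.float])`) with `None` sentinels resolved in the body. Python evaluates default values once, at function definition time, so a mutated default is shared by every later call (flake8-bugbear/ruff rule B006). A standalone demonstration:

    from typing import List, Optional

    def broken(items: List[int] = []) -> List[int]:
        items.append(1)  # mutates the one shared default list
        return items

    def fixed(items: Optional[List[int]] = None) -> List[int]:
        item_list = items if items is not None else []  # fresh list per call
        item_list.append(1)
        return item_list

    assert broken() == [1] and broken() == [1, 1]  # state leaks across calls
    assert fixed() == [1] and fixed() == [1]       # each call is independent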

py/torch_tensorrt/_enums.py (-2)

@@ -1,2 +0,0 @@
-from torch_tensorrt._C import dtype, EngineCapability, TensorFormat
-from tensorrt import DeviceType

py/torch_tensorrt/dynamo/__init__.py (+2 -2)

@@ -2,8 +2,8 @@
 from torch_tensorrt._util import sanitized_torch_version
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from ._settings import *
-    from .conversion import *
+    from ._settings import *  # noqa: F403
+    from .conversion import *  # noqa: F403
     from .aten_tracer import trace
     from .conversion.converter_registry import (
         DYNAMO_CONVERTERS,

py/torch_tensorrt/dynamo/_settings.py (+1 -1)

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Optional, Sequence, Set
+from typing import Optional, Set
 import torch
 from torch_tensorrt.dynamo._defaults import (
     PRECISION,

py/torch_tensorrt/dynamo/aten_tracer.py (-1)

@@ -2,7 +2,6 @@
 import sys
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
-from packaging import version
 
 import torch
 import torch._dynamo as torchdynamo

py/torch_tensorrt/dynamo/backend/__init__.py (-1)

@@ -1 +0,0 @@
-from .backends import torch_tensorrt_backend

py/torch_tensorrt/dynamo/backend/backends.py (+1 -1)

@@ -83,7 +83,7 @@ def _pretraced_backend(
             settings=settings,
         )
         return trt_compiled
-    except:
+    except AssertionError:
         if not settings.pass_through_build_failures:
             logger.warning(
                 "TRT conversion failed on the subgraph. See trace above. "
