Skip to content

Commit 5ff754f

Browse files
committed
chore: adding isort to pre-commit
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent e6a4c08 commit 5ff754f

30 files changed

+162
-167
lines changed

.pre-commit-config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ repos:
4444
hooks:
4545
- id: mypy
4646
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools"
47+
python_version: "3.11"
4748
- repo: local
4849
hooks:
4950
- id: dont-commit-upstream

py/torch_tensorrt/_Device.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
from typing import Optional, Any, Tuple
21
import sys
2+
from typing import Any, Optional, Tuple
33

44
if sys.version_info >= (3, 11):
55
from typing import Self
66
else:
77
from typing_extensions import Self
88

9+
import warnings
10+
911
import torch
12+
from torch_tensorrt import logging
1013

1114
# from torch_tensorrt import _enums
1215
import tensorrt as trt
13-
from torch_tensorrt import logging
14-
import warnings
1516

1617
try:
1718
from torch_tensorrt import _C

py/torch_tensorrt/_Input.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from enum import Enum
2-
from typing import List, Dict, Any, Tuple, Optional, Sequence
2+
from typing import Any, Dict, List, Optional, Sequence, Tuple
33

44
import torch
5-
65
from torch_tensorrt import _enums
76

87

py/torch_tensorrt/__init__.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import ctypes
22
import os
3-
import sys
43
import platform
4+
import sys
5+
from typing import Dict, List
6+
57
from packaging import version
68
from torch_tensorrt._version import (
79
__cuda_version__,
810
__cudnn_version__,
911
__tensorrt_version__,
1012
)
1113

12-
from typing import Dict, List
13-
1414
if sys.version_info < (3,):
1515
raise Exception(
1616
"Python 2 has reached end-of-life and is not supported by Torch-TensorRT"
@@ -81,11 +81,9 @@ def _find_lib(name: str, paths: List[str]) -> str:
8181
ctypes.CDLL(_find_lib(lib, LINUX_PATHS))
8282

8383
import torch
84-
8584
from torch_tensorrt._compile import * # noqa: F403
86-
from torch_tensorrt._util import * # noqa: F403
8785
from torch_tensorrt._enums import * # noqa: F403
88-
86+
from torch_tensorrt._util import * # noqa: F403
8987

9088
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
9189
from torch_tensorrt import dynamo # noqa: F401

py/torch_tensorrt/_compile.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
from typing import List, Any, Set, Callable, TypeGuard, Optional
1+
from enum import Enum
2+
from typing import Any, Callable, List, Optional, Set, TypeGuard
23

4+
import torch
5+
import torch.fx
36
import torch_tensorrt.ts
4-
57
from torch_tensorrt import logging
6-
from torch_tensorrt._Input import Input
78
from torch_tensorrt._enums import dtype
8-
import torch
9-
import torch.fx
10-
from enum import Enum
11-
12-
from torch_tensorrt.fx import InputTensorSpec
13-
from torch_tensorrt.fx.utils import LowerPrecision
14-
9+
from torch_tensorrt._Input import Input
1510
from torch_tensorrt.dynamo.compile import compile as dynamo_compile
11+
from torch_tensorrt.fx import InputTensorSpec
1612
from torch_tensorrt.fx.lower import compile as fx_compile
13+
from torch_tensorrt.fx.utils import LowerPrecision
1714
from torch_tensorrt.ts._compiler import compile as torchscript_compile
1815

1916

py/torch_tensorrt/_enums.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from torch_tensorrt._C import dtype, EngineCapability, TensorFormat # noqa: F401
2+
from tensorrt import DeviceType # noqa: F401

py/torch_tensorrt/_util.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
from torch_tensorrt import __version__
2-
from torch_tensorrt import _C
3-
41
import torch
2+
from torch_tensorrt import _C, __version__
53

64

75
def dump_build_info() -> None:

py/torch_tensorrt/dynamo/_settings.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from dataclasses import dataclass, field
22
from typing import Optional, Set
3+
34
import torch
45
from torch_tensorrt.dynamo._defaults import (
5-
PRECISION,
66
DEBUG,
7-
WORKSPACE_SIZE,
8-
MIN_BLOCK_SIZE,
9-
PASS_THROUGH_BUILD_FAILURES,
107
MAX_AUX_STREAMS,
11-
VERSION_COMPATIBLE,
8+
MIN_BLOCK_SIZE,
129
OPTIMIZATION_LEVEL,
10+
PASS_THROUGH_BUILD_FAILURES,
11+
PRECISION,
1312
USE_PYTHON_RUNTIME,
13+
VERSION_COMPATIBLE,
14+
WORKSPACE_SIZE,
1415
)
1516

1617

py/torch_tensorrt/dynamo/aten_tracer.py

-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import torch._dynamo as torchdynamo
88
from torch import _guards
99
from torch.fx.passes.infra.pass_base import PassResult
10-
1110
from torch_tensorrt.dynamo.utils import req_torch_version
1211
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
1312
compose_bmm,

py/torch_tensorrt/dynamo/backend/backends.py

+8-17
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,16 @@
11
import logging
2-
from typing import Sequence, Any, Callable
3-
import torch
42
from functools import partial
5-
import torch._dynamo as td
3+
from typing import Any, Callable, Sequence
64

5+
import torch
6+
import torch._dynamo as td
7+
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
78
from torch_tensorrt.dynamo import CompilationSettings
8-
from torch_tensorrt.dynamo.lowering._decompositions import (
9-
get_decompositions,
10-
)
11-
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
12-
pre_aot_substitutions,
13-
)
14-
from torch_tensorrt.dynamo.lowering._partition import (
15-
partition,
16-
get_submod_inputs,
17-
)
18-
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
199
from torch_tensorrt.dynamo.conversion import convert_module
20-
21-
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
22-
10+
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
11+
from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
12+
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
13+
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
2314

2415
logger = logging.getLogger(__name__)
2516

py/torch_tensorrt/dynamo/compile.py

+15-20
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,29 @@
1-
import torch
2-
import logging
31
import collections.abc
4-
import torch_tensorrt
2+
import logging
3+
from typing import Any, List, Optional, Set, Tuple
54

6-
from typing import Any, Optional, Set, List, Tuple
7-
from torch_tensorrt import EngineCapability, Device
5+
import torch
6+
import torch_tensorrt
87
from torch.fx.passes.pass_manager import PassManager
98
from torch.fx.passes.splitter_base import SplitResult
10-
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
11-
from torch_tensorrt.dynamo.lowering import (
12-
fuse_permute_linear,
13-
fuse_permute_matmul,
14-
)
9+
from torch_tensorrt import Device, EngineCapability
1510
from torch_tensorrt.dynamo import CompilationSettings
16-
from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
17-
from torch_tensorrt.dynamo.backend.backends import _compile_module
18-
from torch_tensorrt.dynamo.conversion import convert_module
19-
2011
from torch_tensorrt.dynamo._defaults import (
21-
PRECISION,
2212
DEBUG,
23-
WORKSPACE_SIZE,
24-
MIN_BLOCK_SIZE,
25-
PASS_THROUGH_BUILD_FAILURES,
2613
MAX_AUX_STREAMS,
27-
VERSION_COMPATIBLE,
14+
MIN_BLOCK_SIZE,
2815
OPTIMIZATION_LEVEL,
16+
PASS_THROUGH_BUILD_FAILURES,
17+
PRECISION,
2918
USE_PYTHON_RUNTIME,
19+
VERSION_COMPATIBLE,
20+
WORKSPACE_SIZE,
3021
)
31-
22+
from torch_tensorrt.dynamo.backend.backends import _compile_module
23+
from torch_tensorrt.dynamo.conversion import convert_module
24+
from torch_tensorrt.dynamo.lowering import fuse_permute_linear, fuse_permute_matmul
25+
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
26+
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
3227

3328
logger = logging.getLogger(__name__)
3429

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
11
import logging
22
import warnings
33
from datetime import datetime
4-
from packaging import version
54
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
65

76
import numpy
8-
9-
# @manual=//deeplearning/trt/python:py_tensorrt
10-
import tensorrt as trt
117
import torch
128
import torch.fx
9+
from packaging import version
1310
from torch.fx.node import _get_qualified_name
1411
from torch.fx.passes.shape_prop import TensorMetadata
15-
16-
from torch_tensorrt.fx import CONVERTERS
1712
from torch_tensorrt import Input
13+
from torch_tensorrt.fx import CONVERTERS
1814
from torch_tensorrt.fx.observer import Observer
19-
from torch_tensorrt.fx.utils import (
20-
unified_dtype_converter,
21-
Frameworks,
22-
)
15+
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
16+
17+
# @manual=//deeplearning/trt/python:py_tensorrt
18+
import tensorrt as trt
2319

2420
_LOGGER: logging.Logger = logging.getLogger(__name__)
2521

py/torch_tensorrt/dynamo/conversion/conversion.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import io
12
from typing import Sequence
3+
24
import torch
3-
import io
4-
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
5-
from torch_tensorrt.dynamo import CompilationSettings
65
from torch_tensorrt import Input
6+
from torch_tensorrt.dynamo import CompilationSettings
77
from torch_tensorrt.dynamo.conversion import TRTInterpreter
8-
8+
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
99

1010
import tensorrt as trt
1111

Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .substitutions import * # noqa: F403
21
from ._fusers import * # noqa: F403
2+
from .substitutions import * # noqa: F403

py/torch_tensorrt/dynamo/lowering/_decompositions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Callable, Dict, Any
2-
import torch
3-
from torch._decomp import register_decomposition, core_aten_decompositions, OpOverload
1+
from typing import Any, Callable, Dict
42

3+
import torch
4+
from torch._decomp import OpOverload, core_aten_decompositions, register_decomposition
55

66
DECOMPOSITIONS: Dict[OpOverload, Callable[..., Any]] = {**core_aten_decompositions()}
77

py/torch_tensorrt/dynamo/lowering/_partition.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import logging
2-
from typing import List, Optional, Sequence, Set, Mapping
2+
from typing import List, Mapping, Optional, Sequence, Set
33

44
import torch
5-
6-
from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
7-
from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
8-
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
95
from torch.fx.graph_module import GraphModule
106
from torch.fx.node import _get_qualified_name
7+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
118
from torch.fx.passes.operator_support import OperatorSupport, SupportDict
12-
9+
from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
10+
from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
1311
from torch_tensorrt.fx.converter_registry import CONVERTERS
1412

15-
1613
logger = logging.getLogger(__name__)
1714

1815
DEFAULT_SINGLE_NODE_PARTITIONS: List[str] = [

py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
import logging
12
from dataclasses import dataclass
23
from typing import Any, Callable, Dict, Optional, Type, TypeAlias
3-
import torch
4-
import logging
54

6-
from torch.fx import GraphModule, Node
5+
import torch
76
from torch._ops import OpOverload
8-
7+
from torch.fx import GraphModule, Node
98

109
logger = logging.getLogger(__name__)
1110

Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .maxpool1d import * # noqa: F403
21
from .einsum import * # noqa: F403
2+
from .maxpool1d import * # noqa: F403

py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
from typing import Dict, Tuple, Any, Optional, Sequence
1+
from typing import Any, Dict, Optional, Sequence, Tuple
2+
23
import torch
34
from torch._custom_op.impl import custom_op
45
from torch.fx.node import Argument, Target
5-
6+
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
67
from torch_tensorrt.fx.converter_registry import tensorrt_converter
78
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
89
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
910

10-
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
11-
1211

1312
@custom_op(
1413
qualname="tensorrt::einsum",

py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
from typing import Dict, Tuple, Any, Optional
1+
from typing import Any, Dict, Optional, Tuple
2+
23
import torch
34
from torch._custom_op.impl import custom_op
45
from torch.fx.node import Argument, Target
5-
6+
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
67
from torch_tensorrt.fx.converter_registry import tensorrt_converter
78
from torch_tensorrt.fx.converters import acc_ops_converters
89
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
910

10-
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
11-
12-
1311
# This file serves as an example and a tutorial for excluding custom modules from
1412
# torch.compile tracing. Each required step is labeled with a number indicating the
1513
# preferable implementation order.

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Any, List, Sequence, Dict, Tuple, Optional
1+
from typing import Any, Dict, List, Optional, Sequence, Tuple
2+
3+
import torch
4+
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
25

36
# @manual=//deeplearning/trt/python:py_tensorrt
47
import tensorrt as trt
5-
import torch
6-
from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks
78

89

910
class PythonTorchTensorRTModule(torch.nn.Module):

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, List, Tuple, Optional
2+
from typing import Any, List, Optional, Tuple
33

44
import torch
55
from torch_tensorrt._Device import Device
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from ._PythonTorchTensorRTModule import PythonTorchTensorRTModule # noqa: F401
2+
from ._TorchTensorRTModule import TorchTensorRTModule # noqa: F401

0 commit comments

Comments (0)