
Commit 6f802d1

chore: adding isort to pre-commit

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent: 68e6aa8

33 files changed: +206 -140 lines changed

.pre-commit-config.yaml (+1)

@@ -44,6 +44,7 @@ repos:
     hooks:
       - id: mypy
         exclude: "^py/torch_tensorrt/fx|^examples|^tests|^tools"
+        python_version: "3.11"
   - repo: local
     hooks:
       - id: dont-commit-upstream
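
Since this commit wires isort into the pre-commit flow, the reordering it applies throughout the hunks below (standard library first, then third-party, then first-party, each group alphabetized) can be previewed with isort's Python API. A minimal sketch, assuming isort >= 5 is installed; the sample imports are illustrative, not taken from the diff:

import isort

# Imports in the "messy" pre-isort style seen on the left side of the hunks.
messy = (
    "from typing import Optional, Any, Tuple\n"
    "import sys\n"
    "import torch\n"
)

# isort.code() returns the sorted source: the stdlib section first,
# alphabetized by module name, then a blank line, then third-party
# modules such as torch.
print(isort.code(messy))
# import sys
# from typing import Any, Optional, Tuple
#
# import torch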

py/torch_tensorrt/_Device.py (+4 -3)

@@ -1,17 +1,18 @@
-from typing import Optional, Any, Tuple
 import sys
+from typing import Any, Optional, Tuple
 
 if sys.version_info >= (3, 11):
     from typing import Self
 else:
     from typing_extensions import Self
 
+import warnings
+
 import torch
+from torch_tensorrt import logging
 
 # from torch_tensorrt import _enums
 import tensorrt as trt
-from torch_tensorrt import logging
-import warnings
 
 try:
     from torch_tensorrt import _C

py/torch_tensorrt/_Input.py (+1 -2)

@@ -1,8 +1,7 @@
 from enum import Enum
-from typing import List, Dict, Any, Tuple, Optional, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import torch
-
 from torch_tensorrt import _enums
 
 
py/torch_tensorrt/__init__.py (+6 -6)

@@ -1,16 +1,16 @@
 import ctypes
 import os
-import sys
 import platform
+import sys
+from typing import Dict, List
+
 from packaging import version
 from torch_tensorrt._version import (
     __cuda_version__,
     __cudnn_version__,
     __tensorrt_version__,
 )
 
-from typing import Dict, List
-
 if sys.version_info < (3,):
     raise Exception(
         "Python 2 has reached end-of-life and is not supported by Torch-TensorRT"

@@ -81,11 +81,11 @@ def _find_lib(name: str, paths: List[str]) -> str:
         ctypes.CDLL(_find_lib(lib, LINUX_PATHS))
 
 import torch
-
 from torch_tensorrt._compile import *  # noqa: F403
-from torch_tensorrt._util import *  # noqa: F403
 from torch_tensorrt._enums import *  # noqa: F403
-
+from torch_tensorrt._util import *  # noqa: F403
+from torch_tensorrt._Input import Input  # noqa: F401
+from torch_tensorrt._Device import Device  # noqa: F401
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from torch_tensorrt import dynamo  # noqa: F401
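
One behavioral note on the hunk above: reordering "from module import *" lines is not always a pure no-op, because a later star import shadows earlier bindings of the same name. A toy illustration (the modules a and b here are hypothetical, not part of torch_tensorrt):

# a.py defines: value = "a"
# b.py defines: value = "b"
from a import *  # noqa: F403
from b import *  # noqa: F403

print(value)  # prints "b": the later star import wins on name clashes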

py/torch_tensorrt/_compile.py (+7 -10)

@@ -1,19 +1,16 @@
-from typing import List, Any, Set, Callable, TypeGuard, Optional
+from enum import Enum
+from typing import Any, Callable, List, Optional, Set, TypeGuard
 
+import torch
+import torch.fx
 import torch_tensorrt.ts
-
 from torch_tensorrt import logging
-from torch_tensorrt._Input import Input
 from torch_tensorrt._enums import dtype
-import torch
-import torch.fx
-from enum import Enum
-
-from torch_tensorrt.fx import InputTensorSpec
-from torch_tensorrt.fx.utils import LowerPrecision
-
+from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.compile import compile as dynamo_compile
+from torch_tensorrt.fx import InputTensorSpec
 from torch_tensorrt.fx.lower import compile as fx_compile
+from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt.ts._compiler import compile as torchscript_compile
 
 
py/torch_tensorrt/_enums.py (+2)

@@ -0,0 +1,2 @@
+from torch_tensorrt._C import dtype, EngineCapability, TensorFormat  # noqa: F401
+from tensorrt import DeviceType  # noqa: F401

py/torch_tensorrt/_util.py (+1 -3)

@@ -1,7 +1,5 @@
-from torch_tensorrt import __version__
-from torch_tensorrt import _C
-
 import torch
+from torch_tensorrt import _C, __version__
 
 
 def dump_build_info() -> None:

py/torch_tensorrt/dynamo/__init__.py (+5 -5)

@@ -4,10 +4,10 @@
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from ._settings import *  # noqa: F403
     from .conversion import *  # noqa: F403
-    from .aten_tracer import trace
+    from .aten_tracer import trace  # noqa: F403
     from .conversion.converter_registry import (
-        DYNAMO_CONVERTERS,
-        dynamo_tensorrt_converter,
+        DYNAMO_CONVERTERS,  # noqa: F403
+        dynamo_tensorrt_converter,  # noqa: F403
     )
-    from .compile import compile
-    from ._SourceIR import SourceIR
+    from .compile import compile  # noqa: F403
+    from ._SourceIR import SourceIR  # noqa: F403

py/torch_tensorrt/dynamo/_settings.py (+9 -5)

@@ -1,17 +1,21 @@
 from dataclasses import dataclass, field
 from typing import Optional, Set
+
 import torch
 from torch_tensorrt.dynamo._defaults import (
-    PRECISION,
     DEBUG,
-    WORKSPACE_SIZE,
-    MIN_BLOCK_SIZE,
-    PASS_THROUGH_BUILD_FAILURES,
     MAX_AUX_STREAMS,
-    VERSION_COMPATIBLE,
+    MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
+    PASS_THROUGH_BUILD_FAILURES,
+    PRECISION,
     USE_PYTHON_RUNTIME,
+<<<<<<< HEAD
     TRUNCATE_LONG_AND_DOUBLE,
+=======
+    VERSION_COMPATIBLE,
+    WORKSPACE_SIZE,
+>>>>>>> e39abb60d (chore: adding isort to pre-commit)
 )
 
 
py/torch_tensorrt/dynamo/aten_tracer.py (-1)

@@ -7,7 +7,6 @@
 import torch._dynamo as torchdynamo
 from torch import _guards
 from torch.fx.passes.infra.pass_base import PassResult
-
 from torch_tensorrt.dynamo.utils import req_torch_version
 from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
     compose_bmm,
py/torch_tensorrt/dynamo/backend/__init__.py (+1)

@@ -0,0 +1 @@
+from .backends import torch_tensorrt_backend  # noqa: F401

py/torch_tensorrt/dynamo/backend/backends.py (+12 -3)

@@ -1,8 +1,10 @@
 import logging
-from typing import Sequence, Any, Callable
-import torch
 from functools import partial
+from typing import Any, Callable, Sequence
+
+import torch
 import torch._dynamo as td
+<<<<<<< HEAD
 
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.lowering._decompositions import (

@@ -21,8 +23,15 @@
     repair_long_or_double_inputs,
 )
 
+=======
+>>>>>>> e39abb60d (chore: adding isort to pre-commit)
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
-
+from torch_tensorrt.dynamo import CompilationSettings
+from torch_tensorrt.dynamo.conversion import convert_module
+from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
+from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
+from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 
 logger = logging.getLogger(__name__)

py/torch_tensorrt/dynamo/compile.py (+22 -11)

@@ -1,21 +1,27 @@
-import torch
-import logging
 import collections.abc
-import torch_tensorrt
+import logging
+from typing import Any, List, Optional, Set, Tuple
 
-from typing import Any, Optional, Set, List, Tuple
-from torch_tensorrt import EngineCapability, Device
+import torch
+import torch_tensorrt
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.splitter_base import SplitResult
-from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
-from torch_tensorrt.dynamo.lowering import (
-    fuse_permute_linear,
-    fuse_permute_matmul,
-)
+from torch_tensorrt import Device, EngineCapability
 from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
+from torch_tensorrt.dynamo._defaults import (
+    DEBUG,
+    MAX_AUX_STREAMS,
+    MIN_BLOCK_SIZE,
+    OPTIMIZATION_LEVEL,
+    PASS_THROUGH_BUILD_FAILURES,
+    PRECISION,
+    USE_PYTHON_RUNTIME,
+    VERSION_COMPATIBLE,
+    WORKSPACE_SIZE,
+)
 from torch_tensorrt.dynamo.backend.backends import _compile_module
 from torch_tensorrt.dynamo.conversion import convert_module
+<<<<<<< HEAD
 
 from torch_tensorrt.dynamo._defaults import (
     PRECISION,

@@ -30,6 +36,11 @@
     TRUNCATE_LONG_AND_DOUBLE,
 )
 
+=======
+from torch_tensorrt.dynamo.lowering import fuse_permute_linear, fuse_permute_matmul
+from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
+from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
+>>>>>>> e39abb60d (chore: adding isort to pre-commit)
 
 
 logger = logging.getLogger(__name__)
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+12 -1)

@@ -1,13 +1,22 @@
 import logging
 import warnings
 from datetime import datetime
-from packaging import version
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
 import numpy
+import torch
+import torch.fx
+from packaging import version
+from torch.fx.node import _get_qualified_name
+from torch.fx.passes.shape_prop import TensorMetadata
+from torch_tensorrt import Input
+from torch_tensorrt.fx import CONVERTERS
+from torch_tensorrt.fx.observer import Observer
+from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
+<<<<<<< HEAD
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name

@@ -20,6 +29,8 @@
     unified_dtype_converter,
     Frameworks,
 )
+=======
+>>>>>>> e39abb60d (chore: adding isort to pre-commit)
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)

py/torch_tensorrt/dynamo/conversion/conversion.py (+4 -4)

@@ -1,11 +1,11 @@
+import io
 from typing import Sequence
+
 import torch
-import io
-from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt import Input
+from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
-
+from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 
 import tensorrt as trt

py/torch_tensorrt/dynamo/lowering/__init__.py (+10 -2)

@@ -1,2 +1,10 @@
-from .substitutions import *  # noqa: F403
-from ._fusers import *  # noqa: F403
+from ._decompositions import (
+    get_decompositions,  # noqa: F401
+)
+from ._pre_aot_lowering import (
+    SUBSTITUTION_REGISTRY,  # noqa: F401
+    register_substitution,  # noqa: F401
+)
+from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS  # noqa: F401
+from .substitutions import *  # noqa: F403
+from ._fusers import *  # noqa: F403
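
The parenthesized imports in this hunk match isort's vertical-hanging-indent wrap style. The project's actual isort configuration is not shown in this commit, but the effect can be approximated with the black-compatible profile (a sketch, assuming isort >= 5):

import isort

# An import statement longer than the 88-character limit of profile="black".
long_line = (
    "from torch_tensorrt.dynamo.lowering._pre_aot_lowering "
    "import SUBSTITUTION_REGISTRY, register_substitution\n"
)

# profile="black" wraps over-long imports into a parenthesized,
# one-name-per-line form with a trailing comma.
print(isort.code(long_line, profile="black"))
# from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
#     SUBSTITUTION_REGISTRY,
#     register_substitution,
# )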

py/torch_tensorrt/dynamo/lowering/_decompositions.py (+3 -3)

@@ -1,7 +1,7 @@
-from typing import Callable, Dict, Any
-import torch
-from torch._decomp import register_decomposition, core_aten_decompositions, OpOverload
+from typing import Any, Callable, Dict
 
+import torch
+from torch._decomp import OpOverload, core_aten_decompositions, register_decomposition
 
 DECOMPOSITIONS: Dict[OpOverload, Callable[..., Any]] = {**core_aten_decompositions()}

py/torch_tensorrt/dynamo/lowering/_partition.py (+8 -6)

@@ -1,17 +1,19 @@
 import logging
-from typing import List, Optional, Sequence, Set, Mapping
+from typing import List, Mapping, Optional, Sequence, Set
 
 import torch
-
-from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
-from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
-from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.passes.operator_support import OperatorSupport, SupportDict
+<<<<<<< HEAD
 
 from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS
-
+=======
+from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY
+from torch_tensorrt.fx.converter_registry import CONVERTERS
+>>>>>>> e39abb60d (chore: adding isort to pre-commit)
 
 logger = logging.getLogger(__name__)

py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py (+3 -4)

@@ -1,11 +1,10 @@
+import logging
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Optional, Type, TypeAlias
-import torch
-import logging
 
-from torch.fx import GraphModule, Node
+import torch
 from torch._ops import OpOverload
-
+from torch.fx import GraphModule, Node
 
 logger = logging.getLogger(__name__)

py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py (+1 -1)

@@ -1,2 +1,2 @@
-from .maxpool1d import *  # noqa: F403
 from .einsum import *  # noqa: F403
+from .maxpool1d import *  # noqa: F403

py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py (+3 -4)

@@ -1,14 +1,13 @@
-from typing import Dict, Tuple, Any, Optional, Sequence
+from typing import Any, Dict, Optional, Sequence, Tuple
+
 import torch
 from torch._custom_op.impl import custom_op
 from torch.fx.node import Argument, Target
-
+from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
-from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
-
 
 @custom_op(
     qualname="tensorrt::einsum",
