Skip to content

Commit 81d6bcc

Browse files
authored
Merge pull request #2189 from pytorch/py38_compatibility
Py38 compatibility
2 parents: b3089bf + f53a823 — commit 81d6bcc

15 files changed: +47 −15 lines changed

py/torch_tensorrt/_Input.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from enum import Enum
24
from typing import Any, Dict, List, Optional, Sequence, Tuple
35

@@ -32,11 +34,11 @@ class _ShapeMode(Enum):
3234
shape: Optional[
3335
Tuple[int, ...] | Dict[str, Tuple[int, ...]]
3436
] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
35-
dtype: _enums.dtype = ( # type: ignore[name-defined]
37+
dtype: _enums.dtype = (
3638
_enums.dtype.unknown
3739
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
3840
_explicit_set_dtype: bool = False
39-
format: _enums.TensorFormat = ( # type: ignore[name-defined]
41+
format: _enums.TensorFormat = (
4042
_enums.TensorFormat.contiguous
4143
) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
4244

@@ -208,7 +210,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
208210
return False
209211

210212
@staticmethod
211-
def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
213+
def _parse_dtype(dtype: Any) -> _enums.dtype:
212214
if isinstance(dtype, torch.dtype):
213215
if dtype == torch.long:
214216
return _enums.dtype.long
@@ -236,7 +238,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
236238
)
237239

238240
@staticmethod
239-
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: # type: ignore[name-defined]
241+
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
240242
if dtype == _enums.dtype.long:
241243
return torch.long
242244
elif dtype == _enums.dtype.int32:
@@ -255,7 +257,7 @@ def is_trt_dtype(self) -> bool:
255257
return bool(self.dtype != _enums.dtype.long)
256258

257259
@staticmethod
258-
def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defined]
260+
def _parse_format(format: Any) -> _enums.TensorFormat:
259261
if isinstance(format, torch.memory_format):
260262
if format == torch.contiguous_format:
261263
return _enums.TensorFormat.contiguous

py/torch_tensorrt/_compile.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from __future__ import annotations
2+
13
from enum import Enum
2-
from typing import Any, Callable, List, Optional, Sequence, Set, TypeGuard
4+
from typing import Any, Callable, List, Optional, Sequence, Set
35

46
import torch
57
import torch.fx
@@ -12,6 +14,7 @@
1214
from torch_tensorrt.fx.lower import compile as fx_compile
1315
from torch_tensorrt.fx.utils import LowerPrecision
1416
from torch_tensorrt.ts._compiler import compile as torchscript_compile
17+
from typing_extensions import TypeGuard
1518

1619

1720
def _non_fx_input_interface(

py/torch_tensorrt/dynamo/aten_tracer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import copy
24
import sys
35
from contextlib import contextmanager
4-
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
6+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
57

68
import torch
79
import torch._dynamo as torchdynamo
@@ -22,7 +24,7 @@
2224
)
2325
from typing_extensions import TypeAlias
2426

25-
Value: TypeAlias = Tuple["Value", ...] | List["Value"] | Dict[str, "Value"]
27+
Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]
2628

2729

2830
class DynamoConfig:

py/torch_tensorrt/dynamo/backend/backends.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import logging
24
from functools import partial
35
from typing import Any, Callable, Sequence

py/torch_tensorrt/dynamo/compile.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import collections.abc
24
import logging
35
from typing import Any, List, Optional, Set, Tuple

py/torch_tensorrt/dynamo/conversion/conversion.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import io
24
from typing import Sequence
35

py/torch_tensorrt/dynamo/conversion/converter_registry.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import logging
24
from dataclasses import dataclass, field
35
from enum import Enum, auto
@@ -28,7 +30,7 @@
2830
Dict[str, Argument],
2931
str,
3032
],
31-
TRTTensor | Sequence[TRTTensor],
33+
Union[TRTTensor, Sequence[TRTTensor]],
3234
]
3335

3436

py/torch_tensorrt/dynamo/conversion/impl/shape.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import List, Optional, Tuple
24

35
import numpy as np

py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Optional, Sequence, Set
24

35
import torch

py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from __future__ import annotations
2+
13
import logging
24
from dataclasses import dataclass
3-
from typing import Any, Callable, Dict, Optional, Type, TypeAlias
5+
from typing import Any, Callable, Dict, Optional, Type
46

57
import torch
68
from torch._ops import OpOverload
79
from torch.fx import GraphModule, Node
10+
from typing_extensions import TypeAlias
811

912
logger = logging.getLogger(__name__)
1013

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Any, Dict, List, Optional, Sequence, Tuple
24

35
import torch

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import logging
24
from typing import Any, List, Optional, Tuple
35

py/torch_tensorrt/dynamo/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import logging
24
from dataclasses import fields, replace
35
from typing import Any, Callable, Dict, Optional, Sequence

py/torch_tensorrt/ts/_compile_spec.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from copy import deepcopy
24
from typing import Any, Dict, List, Optional, Set
35

@@ -39,7 +41,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
3941
)
4042

4143

42-
def _parse_op_precision(precision: Any) -> _enums.dtype: # type: ignore[name-defined]
44+
def _parse_op_precision(precision: Any) -> _enums.dtype:
4345
if isinstance(precision, torch.dtype):
4446
if precision == torch.int8:
4547
return _enums.dtype.int8
@@ -63,7 +65,7 @@ def _parse_op_precision(precision: Any) -> _enums.dtype: # type: ignore[name-de
6365
)
6466

6567

66-
def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: # type: ignore[name-defined]
68+
def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]:
6769
parsed_precisions = set()
6870
if any(isinstance(precisions, type) for type in [list, tuple, set]):
6971
for p in precisions:
@@ -73,7 +75,7 @@ def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: # type: ig
7375
return parsed_precisions
7476

7577

76-
def _parse_device_type(device: Any) -> _enums.DeviceType: # type: ignore[name-defined]
78+
def _parse_device_type(device: Any) -> _enums.DeviceType:
7779
if isinstance(device, torch.device):
7880
if device.type == "cuda":
7981
return _C.DeviceType.gpu
@@ -346,10 +348,10 @@ def TensorRTCompileSpec(
346348
device: torch.device | Device = Device._current_device(),
347349
disable_tf32: bool = False,
348350
sparse_weights: bool = False,
349-
enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, # type: ignore[name-defined]
351+
enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
350352
refit: bool = False,
351353
debug: bool = False,
352-
capability: _enums.EngineCapability = _enums.EngineCapability.default, # type: ignore[name-defined]
354+
capability: _enums.EngineCapability = _enums.EngineCapability.default,
353355
num_avg_timing_iters: int = 1,
354356
workspace_size: int = 0,
355357
dla_sram_size: int = 1048576,

py/torch_tensorrt/ts/_compiler.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Any, List, Optional, Sequence, Set, Tuple
24

35
import torch

0 commit comments

Comments (0)