Skip to content

Commit bd701af

Browse files
committed
feat(//py)!: Implementing top level python api changes to reflect new
Input type and enabled_precisions set BREAKING CHANGE: This commit introduces the next iteration of the Python TRTorch API. Starting in TRTorch v0.5.0 support for the "input_shapes" and "op_precision" compile spec keys will be removed. Users should port forward to using the "inputs" key which expects a list of trtorch.Input objects and the "enabled_precisions" key which expects a set of data type specifying enums. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent f00de94 commit bd701af

10 files changed

+391
-151
lines changed

py/setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def run(self):
181181
include_dirs=[
182182
dir_path + "trtorch/csrc",
183183
dir_path + "/../",
184-
dir_path + "/../bazel-TRTorch/external/tensorrt/include",
184+
dir_path + "/../bazel-trtorch-testing/external/tensorrt/include",
185185
],
186186
extra_compile_args=[
187187
"-Wno-deprecated",

py/trtorch/Input.py

+181
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from enum import Enum
2+
from typing import List, Dict, Any
3+
4+
import torch
5+
6+
from trtorch import _types
7+
import trtorch._C
8+
9+
class Input(object):
10+
"""
11+
Defines an input to a module in terms of expected shape, data type and tensor format.
12+
13+
Attributes:
14+
shape_mode (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
15+
shape (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape.
16+
Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
17+
``{
18+
"min_shape": Tuple,
19+
"opt_shape": Tuple,
20+
"max_shape": Tuple
21+
}``
22+
dtype (trtorch.dtype): The expected data type of the input tensor (default: trtorch.dtype.float32)
23+
format (trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
24+
"""
25+
26+
class _ShapeMode(Enum):
27+
STATIC = 0
28+
DYNAMIC = 1
29+
30+
shape_mode = None
31+
shape = None
32+
dtype = _types.dtype.float32
33+
format = _types.TensorFormat.contiguous
34+
35+
def __init__(self, *args, **kwargs):
36+
""" __init__ Method for trtorch.Input
37+
38+
Input accepts one of a few construction patterns
39+
40+
Args:
41+
shape (Tuple or List, optional): Static shape of input tensor
42+
43+
Keyword Arguments:
44+
shape (Tuple or List, optional): Static shape of input tensor
45+
min_shape (Tuple or List, optional): Min size of input tensor's shape range
46+
Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implictly this sets Input's shape_mode to DYNAMIC
47+
opt_shape (Tuple or List, optional): Opt size of input tensor's shape range
48+
Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implictly this sets Input's shape_mode to DYNAMIC
49+
max_shape (Tuple or List, optional): Max size of input tensor's shape range
50+
Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implictly this sets Input's shape_mode to DYNAMIC
51+
dtype (torch.dtype or trtorch.dtype): Expected data type for input tensor (default: trtorch.dtype.float32)
52+
format (torch.memory_format or trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
53+
54+
Examples:
55+
- Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last)
56+
- Input(shape=(1,3,32,32), dtype=trtorch.dtype.int32, format=trtorch.TensorFormat.NCHW)
57+
- Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=trtorch.dtype.float32, format=trtorch.TensorFormat.NCHW
58+
"""
59+
if len(args) == 1:
60+
if not Input._supported_input_size_type(args[0]):
61+
raise TypeError(
62+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
63+
+ str(type(args[0])))
64+
if any(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
65+
raise ValueError("Found that both shape (as a positional argument), and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
66+
self.shape = tuple(args[0])
67+
self.shape_mode = Input._ShapeMode.STATIC
68+
69+
elif len(args) == 0:
70+
if not ("shape" in kwargs) and not(all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"])):
71+
raise ValueError("Missing required arguments for class Input\nEither shape or all three of min_shape, opt_shape, max_shape must be defined")
72+
elif ("shape" in kwargs) and all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]):
73+
raise ValueError("Found that both shape, and one or more of min_shape, opt_shape, max_shape were specified\nclass Input expects that only either shape or all three of min_shape, opt_shape, max_shape are defined")
74+
75+
if "shape" in kwargs:
76+
if not Input._supported_input_size_type(kwargs["shape"]):
77+
raise TypeError(
78+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
79+
+ str(type(kwargs["shape"])))
80+
self.shape = tuple(kwargs["shape"])
81+
self.shape_mode = Input._ShapeMode.STATIC
82+
else:
83+
if not Input._supported_input_size_type(kwargs["min_shape"]):
84+
raise TypeError(
85+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
86+
+ str(type(kwargs["min_shape"])) + " for min_shape")
87+
if not Input._supported_input_size_type(kwargs["opt_shape"]):
88+
raise TypeError(
89+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
90+
+ str(type(kwargs["opt_shape"])) + " for opt_shape")
91+
if not Input._supported_input_size_type(kwargs["max_shape"]):
92+
raise TypeError(
93+
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
94+
+ str(type(kwargs["max_shape"])) + " for max_shape")
95+
96+
self.shape = {
97+
"min_shape": tuple(kwargs["min_shape"]),
98+
"opt_shape": tuple(kwargs["opt_shape"]),
99+
"max_shape": tuple(kwargs["max_shape"])
100+
}
101+
self.shape_mode = Input._ShapeMode.DYNAMIC
102+
103+
if "dtype" in kwargs:
104+
self.dtype = Input._parse_dtype(kwargs["dtype"])
105+
106+
if "format" in kwargs:
107+
self.format = Input._parse_format(kwargs["format"])
108+
109+
else:
110+
raise ValueError("Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional arguments".format(len(args)))
111+
112+
def __str__(self) -> str:
113+
if self.shape_mode == Input._ShapeMode.STATIC:
114+
return "Input(shape={}, dtype={}, format={})".format(self.shape, str(self.dtype), str(self.format))
115+
elif self.shape_mode == Input._ShapeMode.DYNAMIC:
116+
return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format(self.shape["min_shape"], self.shape["min_shape"], self.shape["min_shape"], str(self.dtype), str(self.format))
117+
else:
118+
raise RuntimeError("Unknown input shape mode")
119+
120+
def _to_internal(self) -> trtorch._C.Input:
121+
internal_in = trtorch._C.Input()
122+
if self.shape_mode == Input._ShapeMode.DYNAMIC:
123+
internal_in.min = self.shape["min_shape"]
124+
internal_in.opt = self.shape["opt_shape"]
125+
internal_in.max = self.shape["max_shape"]
126+
internal_in.input_is_dynamic = True
127+
else:
128+
internal_in.opt = self.shape
129+
internal_in.input_is_dynamic = False
130+
internal_in.dtype = self.dtype
131+
internal_in.format = self.format
132+
return internal_in
133+
134+
@staticmethod
135+
def _supported_input_size_type(input_size: Any) -> bool:
136+
if isinstance(input_size, torch.Size):
137+
return True
138+
elif isinstance(input_size, tuple):
139+
return True
140+
elif isinstance(input_size, list):
141+
return True
142+
else:
143+
return False
144+
145+
@staticmethod
146+
def _parse_dtype(dtype: Any) -> _types.dtype:
147+
if isinstance(dtype, torch.dtype):
148+
if dtype == torch.int32:
149+
return _types.dtype.int32
150+
elif dtype == torch.half:
151+
return _types.dtype.half
152+
elif dtype == torch.float:
153+
return _types.dtype.float
154+
elif dtype == torch.bool:
155+
return _types.dtype.bool
156+
else:
157+
raise TypeError("Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: " +
158+
str(dtype))
159+
160+
elif isinstance(dtype, _types.DataTypes):
161+
return dtype
162+
163+
else:
164+
raise TypeError("Input data type needs to be specified with a torch.dtype or a trtorch.dtype, got: " +
165+
str(type(dtype)))
166+
167+
@staticmethod
168+
def _parse_format(format: Any) -> _types.TensorFormat:
169+
if isinstance(format, torch.memory_format):
170+
if format == torch.contiguous_format:
171+
return _types.TensorFormat.contiguous
172+
elif format == torch.channels_last:
173+
return _types.TensorFormat.channel_last
174+
else:
175+
raise ValueError("Provided an unsupported tensor format (support: NHCW/contiguous_format, NHWC/channel_last)")
176+
177+
elif isinstance(format, _types.TensorFormat):
178+
return format
179+
180+
else:
181+
raise TypeError("Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")

py/trtorch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from trtorch import ptq
1414
from trtorch._types import *
1515
from trtorch import logging
16+
from trtorch.Input import Input
1617

1718

1819
def _register_with_torch():

py/trtorch/_compile_spec.py

+63-63
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
from typing import List, Dict, Any
1+
from typing import List, Dict, Any, Set
22
import torch
33
import trtorch._C
44
from trtorch import _types
5+
from trtorch.Input import Input
6+
7+
import warnings
58

69

710
def _supported_input_size_type(input_size: Any) -> bool:
@@ -26,36 +29,23 @@ def _parse_input_ranges(input_sizes: List) -> List:
2629
for i in input_sizes:
2730
if isinstance(i, dict):
2831
if all(k in i for k in ["min", "opt", "min"]):
29-
in_range = trtorch._C.InputRange()
30-
in_range.min = i["min"]
31-
in_range.opt = i["opt"]
32-
in_range.max = i["max"]
33-
parsed_input_sizes.append(in_range)
32+
parsed_input_sizes.append(Input(min_shape=i["min"], opt_shape=i["opt"], max_shape=i["max"])._to_internal())
3433

3534
elif "opt" in i:
36-
in_range = trtorch._C.InputRange()
37-
in_range.min = i["opt"]
38-
in_range.opt = i["opt"]
39-
in_range.max = i["opt"]
40-
parsed_input_sizes.append(in_range)
35+
parsed_input_sizes.append(Input(shape=i["opt"])._to_internal())
4136

4237
else:
4338
raise KeyError(
4439
"An input size must either be a static size or a range of three sizes (min, opt, max) as Dict")
4540

4641
elif isinstance(i, list):
47-
in_range = trtorch._C.InputRange()
48-
in_range.min = i
49-
in_range.opt = i
50-
in_range.max = i
51-
parsed_input_sizes.append(in_range)
42+
parsed_input_sizes.append(Input(shape=i)._to_internal())
5243

5344
elif isinstance(i, tuple):
54-
in_range = trtorch._C.InputRange()
55-
in_range.min = list(i)
56-
in_range.opt = list(i)
57-
in_range.max = list(i)
58-
parsed_input_sizes.append(in_range)
45+
parsed_input_sizes.append(Input(shape=i)._to_internal())
46+
47+
elif isinstance(i, torch.Size):
48+
parsed_input_sizes.append(Input(shape=i)._to_internal())
5949

6050
return parsed_input_sizes
6151

@@ -80,6 +70,15 @@ def _parse_op_precision(precision: Any) -> _types.dtype:
8070
str(type(precision)))
8171

8272

73+
def _parse_enabled_precisions(precisions: Any) -> Set:
74+
parsed_precisions = set()
75+
if any([isinstance(precisions, type) for type in [list, tuple, set]]):
76+
for p in precisions:
77+
parsed_precisions.add(_parse_op_precision(p))
78+
else:
79+
parsed_precisions.add(_parse_op_precision(precisions))
80+
return parsed_precisions
81+
8382
def _parse_device_type(device: Any) -> _types.DeviceType:
8483
if isinstance(device, torch.device):
8584
if device.type == 'cuda':
@@ -140,39 +139,36 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFall
140139

141140
return info
142141

143-
def _parse_input_dtypes(input_dtypes: List) -> List:
144-
parsed_input_dtypes = []
145-
for dtype in input_dtypes:
146-
if isinstance(dtype, torch.dtype):
147-
if dtype == torch.int8:
148-
parsed_input_dtypes.append(_types.dtype.int8)
149-
elif dtype == torch.half:
150-
parsed_input_dtypes.append(_types.dtype.half)
151-
elif dtype == torch.float:
152-
parsed_input_dtypes.append(_types.dtype.float)
153-
elif dtype == torch.int32:
154-
parsed_input_dtypes.append(_types.dtype.int32)
155-
elif dtype == torch.bool:
156-
parsed_input_dtypes.append(_types.dtype.bool)
157-
else:
158-
raise TypeError("Invalid input dtype. Supported input datatypes include float|half|int8|int32|bool), got: " + str(dtype))
159-
160-
return parsed_input_dtypes
161-
162142
def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
163143
info = trtorch._C.CompileSpec()
164-
if "input_shapes" not in compile_spec:
144+
if "input_shapes" not in compile_spec and "inputs" not in compile_spec:
165145
raise KeyError(
166-
"Input shapes for inputs are required as a List, provided as either a static sizes or a range of three sizes (min, opt, max) as Dict"
146+
"Module input definitions are requried to compile module. Provide a list of trtorch.Input keyed to \"inputs\" in the compile spec"
167147
)
168148

169-
info.input_ranges = _parse_input_ranges(compile_spec["input_shapes"])
149+
if "input_shapes" in compile_spec and "inputs" in compile_spec:
150+
raise KeyError(
151+
"Found both key \"input_shapes\", and \"inputs\" in compile spec, please port forward to using only \"inputs\""
152+
)
153+
154+
if "input_shapes" in compile_spec:
155+
warnings.warn("Key \"input_shapes\" is deprecated in favor of \"inputs\". Support for \"input_shapes\" will be removed in TRTorch v0.5.0", DeprecationWarning)
156+
info.inputs = _parse_input_ranges(compile_spec["input_shapes"])
157+
158+
if "inputs" in compile_spec:
159+
info.inputs = [ i._to_internal() for i in compile_spec["inputs"] ]
160+
161+
if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
162+
raise KeyError(
163+
"Found both key \"op_precision\", and \"enabled_precisions\" in compile spec, please port forward to using only \"enabled_precisions\""
164+
)
170165

171166
if "op_precision" in compile_spec:
172-
info.op_precision = _parse_op_precision(compile_spec["op_precision"])
167+
warnings.warn("Key \"op_precision\" is being deprecated in favor of \"enabled_precision\" which expects a set of precisions to be enabled during compilation (FP32 will always be enabled), Support for \"op_precision\" will be removed in TRTorch v0.5.0", DeprecationWarning)
168+
info.enabled_precisions = _parse_enabled_precisions(compile_spec["op_precision"])
173169

174-
if "input_dtypes" in compile_spec:
175-
info.input_dtypes = _parse_input_dtypes(compile_spec["input_dtypes"])
170+
if "enabled_precisions" in compile_spec:
171+
info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precision"])
176172

177173
if "calibrator" in compile_spec:
178174
info.ptq_calibrator = compile_spec["calibrator"]
@@ -233,7 +229,8 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
233229
Args:
234230
compile_spec (dict): Compilation settings including operating precision, target device, etc.
235231
One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
236-
to the graph. All other keys are optional. Entries for each method to be compiled.
232+
to the graph as well as expect types and formats for those inputs. All other keys are optional.
233+
Entries for each method to be compiled.
237234
238235
Note: Partial compilation of TorchScript modules is not supported through the PyTorch TensorRT backend
239236
If you need this feature, use trtorch.compile to compile your module. Usage of the resulting module is
@@ -243,13 +240,15 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
243240
244241
CompileSpec = {
245242
"forward" : trtorch.TensorRTCompileSpec({
246-
"input_shapes": [
247-
(1, 3, 224, 224), # Static input shape for input #1
248-
{
249-
"min": (1, 3, 224, 224),
250-
"opt": (1, 3, 512, 512),
251-
"max": (1, 3, 1024, 1024)
252-
} # Dynamic input shape for input #2
243+
"inputs": [
244+
trtorch.Input((1, 3, 224, 224)), # Static input shape for input #1
245+
trtorch.Input(
246+
min_shape=1, 3, 224, 224),
247+
opt_shape=(1, 3, 512, 512),
248+
max_shape=(1, 3, 1024, 1024),
249+
dtype=torch.int32
250+
format=torch.channel_last
251+
) # Dynamic input shape for input #2
253252
],
254253
"device": {
255254
"device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
@@ -284,12 +283,15 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
284283

285284
backend_spec = torch.classes.tensorrt.CompileSpec()
286285

287-
for i in parsed_spec.input_ranges:
288-
ir = torch.classes.tensorrt._InputRange()
289-
ir._set_min(i.min)
290-
ir._set_opt(i.opt)
291-
ir._set_max(i.max)
292-
backend_spec._append_input_range(ir)
286+
for i in parsed_spec.inputs:
287+
clone = torch.classes.tensorrt._Input()
288+
clone._set_min(i.min)
289+
clone._set_opt(i.opt)
290+
clone._set_max(i.max)
291+
clone._set_dtype(i.dtype)
292+
clone._set_format(i.format)
293+
clone._set_input_is_dynamic(i.input_is_dynamic)
294+
backend_spec._append_input(clone)
293295

294296
d = torch.classes.tensorrt._Device()
295297
d._set_device_type(int(parsed_spec.device.device_type))
@@ -309,9 +311,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
309311

310312
backend_spec._set_device(d)
311313
backend_spec._set_torch_fallback(torch_fallback)
312-
backend_spec._set_op_precision(int(parsed_spec.op_precision))
313-
for dtype in parsed_spec.input_dtypes:
314-
backend_spec._append_input_dtypes(int64_t(dtype))
314+
backend_spec._set_precisions([int(i) for i in parsed_spec.enabled_precisions])
315315

316316
backend_spec._set_disable_tf32(parsed_spec.disable_tf32)
317317
backend_spec._set_refit(parsed_spec.refit)

0 commit comments

Comments
 (0)