
Commit 5ad9826

Wei and wwei6 authored
Changes done internally at Facebook (#1178)
6703b98dff0695d91026f057b951dba1355825fa Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.prod
c822345d6d673e1653c2208435e34ab400bada3d Jason Park <[email protected]> Add support for generic torch ops to be used in training.

Co-authored-by: wwei6 <[email protected]>
1 parent 2fd564e commit 5ad9826

File tree

7 files changed: +178 -42 lines changed


py/torch_tensorrt/fx/__init__.py

+2 -1

@@ -6,5 +6,6 @@
     tensorrt_converter,
 )
 from .fx2trt import TRTInterpreter, TRTInterpreterResult  # noqa
-from .input_tensor_spec import InputTensorSpec  # noqa
+from .input_tensor_spec import generate_input_specs, InputTensorSpec  # noqa
+from .lower_setting import LowerSetting  # noqa
 from .trt_module import TRTModule  # noqa
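
With the re-export above, the new helper is importable from the package root alongside the existing names. A minimal sketch (the tensor shape is illustrative):

import torch
# All three names below are now importable from the package root.
from torch_tensorrt.fx import generate_input_specs, InputTensorSpec, LowerSetting

# InputTensorSpec.from_tensors builds static-shape specs from example tensors.
specs = InputTensorSpec.from_tensors([torch.randn(2, 3, 4)])
print(specs[0].shape)  # (2, 3, 4)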

py/torch_tensorrt/fx/input_tensor_spec.py

+71 -5

@@ -1,11 +1,66 @@
-from typing import Iterable, List, NamedTuple, Sequence, Tuple
+from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple

 import torch

 from .types import Shape, ShapeRange
 from .utils import get_dynamic_dims


+def generate_input_specs(
+    inputs, lower_setting, additional_inputs=None, fixed_shape=False
+):
+    # AIT lower setting doesn't have explicit_batch_dimension field and
+    # we just return None.
+    if not hasattr(lower_setting, "explicit_batch_dimension"):
+        return None
+
+    if not lower_setting.explicit_batch_dimension or fixed_shape:
+        return InputTensorSpec.from_tensors(inputs)
+
+    # If we don't have additional inputs, we assume the first dimension
+    # is the dynamic batch dimension. Otherwise, we use the additional
+    # inputs to determine the batch dimension.
+    if additional_inputs is None:
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+        )
+    else:
+        batch_dims = []
+
+        for i, j in zip(inputs, additional_inputs):
+            found_batch_dim = False
+
+            for idx, values in enumerate(zip(i.shape, j.shape)):
+                if values[0] != values[1]:
+                    assert (
+                        found_batch_dim is False
+                    ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+                    batch_dims.append(idx)
+                    found_batch_dim = True
+
+            if not found_batch_dim:
+                raise RuntimeError(
+                    f"Failed to find batch dimension because shapes are the same, {i.shape}"
+                )
+
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+            batch_dims,
+        )
+
+
 class InputTensorSpec(NamedTuple):
     """
     This class contains the information of a input tensor.

@@ -70,6 +125,7 @@ def from_tensors_with_dynamic_batch_size(
         tensors: Sequence[torch.Tensor],
         batch_size_range: Tuple[int, int, int],
         opt_profile_replica: int = 1,
+        batch_dims: Optional[List[int]] = None,
     ) -> List["InputTensorSpec"]:
         """
         Produce a list of InputTenosrSpec named tuples which would contain

@@ -83,20 +139,30 @@ def from_tensors_with_dynamic_batch_size(
                 the smallest batch size allowed. The second integer indiceates
                 the batch size that we'll optimize for. The third integer indicates
                 the largest batch size allowed.
+            opt_profile_replica (int): If dynamic shape is enabled, each execution
+                context requires a different optimization profile. This arg determines
+                how many optimization profile replicas we want to produce.
+            batch_dims (Optional[List[int]]): The batch dim might not be the leading dim
+                and allow user to specify the batch dims using this arg. Default we treat
+                dim 0 as the batch dim.

         Returns:
             A list of InputTensorSpec named tuples with dynamic ranges.
         """
+        if batch_dims is None:
+            batch_dims = [0] * len(tensors)
+
         input_specs = []
-        batch_size = tensors[0].size(0)
+        batch_size = tensors[0].size(batch_dims[0])

         for i, tensor in enumerate(tensors):
+            batch_dim = batch_dims[i]
             assert batch_size == tensor.size(
-                0
+                batch_dim
            ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
             shape = list(tensor.shape)
-            shape[0] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
+            shape[batch_dim] = -1
+            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
             input_specs.append(
                 cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
             )
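
A quick sketch of how the new generate_input_specs infers the batch dim from additional_inputs. The _Setting stub below is hypothetical; it only carries the three fields the function reads, and the second sample input differs from the first only in dim 1:

import torch
from torch_tensorrt.fx import generate_input_specs

class _Setting:
    # Hypothetical stand-in for LowerSetting; only these fields are read here.
    explicit_batch_dimension = True
    max_batch_size = 64
    opt_profile_replica = 1

inputs = [torch.randn(8, 4, 16)]
additional_inputs = [torch.randn(8, 2, 16)]  # differs from inputs only in dim 1

specs = generate_input_specs(inputs, _Setting(), additional_inputs)
print(specs[0].shape)         # (8, -1, 16): dim 1 was inferred as the batch dim
print(specs[0].shape_ranges)  # [((8, 0, 16), (8, 64, 16), (8, 64, 16))]

Without additional_inputs, dim 0 would be assumed to be the dynamic batch dim instead.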

py/torch_tensorrt/fx/lower.py

+9 -29

@@ -1,6 +1,6 @@
 import dataclasses as dc
 import logging
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Optional, Sequence

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt

@@ -10,15 +10,9 @@
 from torch.fx.passes.splitter_base import SplitResult

 from .fx2trt import TRTInterpreter, TRTInterpreterResult
-from .input_tensor_spec import InputTensorSpec
 from .lower_setting import LowerSetting
 from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
-from .passes.pass_utils import (
-    chain_passes,
-    decorate_method,
-    PassFunc,
-    validate_inference,
-)
+from .passes.pass_utils import decorate_method, PassFunc, validate_inference
 from .tools.timing_cache_utils import TimingCacheManager
 from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting

@@ -91,25 +85,8 @@ def create(cls, lower_setting):
         return LowerTrtInterpreter(lower_setting, timing_cache_manager)

     def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
-        input_specs_val = (
-            self.lower_setting.input_specs
-            if self.lower_setting.input_specs
-            else (
-                InputTensorSpec.from_tensors_with_dynamic_batch_size(
-                    input,
-                    (
-                        0,
-                        self.lower_setting.max_batch_size,
-                        self.lower_setting.max_batch_size,
-                    ),
-                    self.lower_setting.opt_profile_replica,
-                )
-                if self.lower_setting.explicit_batch_dimension
-                and self.lower_setting.dynamic_batch
-                else InputTensorSpec.from_tensors(input)
-            )
-        )
-        logger.info(f"{split_name=} {input_specs_val=}")
+        assert self.lower_setting.input_specs, "Can't find input specs for lowering!"
+        logger.info(f"{split_name=} {self.lower_setting.input_specs=}")

         # Prepare algorithm selector and timing_cache for TRTInterpreter
         algo_selector = None

@@ -125,7 +102,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:

         interpreter = TRTInterpreter(
             mod,
-            input_specs=input_specs_val,
+            input_specs=self.lower_setting.input_specs,
             explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
             explicit_precision=self.lower_setting.explicit_precision,
             logger_level=trt.Logger.VERBOSE

@@ -242,6 +219,7 @@ def __call__(
         self,
         module: nn.Module,
         inputs: Input,
+        additional_inputs: Optional[Input] = None,
     ) -> nn.Module:
         module.eval()

@@ -254,7 +232,9 @@ def __call__(
                 x.half() if x is not None and x.dtype == torch.float32 else x
                 for x in inputs
             )
-        pm = self.lower_pass_manager_builder.build_lower_pipeline(inputs)
+        pm = self.lower_pass_manager_builder.build_lower_pipeline(
+            inputs, additional_inputs
+        )

         lower_result = pm(module)

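
End to end, the new argument reaches the lowering entry point shown above. A hedged sketch of a caller, assuming the surrounding class is Lowerer with the create(lower_setting) factory this file already provides (requires a TensorRT-capable GPU):

import torch
from torch_tensorrt.fx.lower import Lowerer
from torch_tensorrt.fx.lower_setting import LowerSetting

module = torch.nn.Linear(16, 4).eval().cuda()
inputs = [torch.randn(8, 16).cuda()]
# Same shape as `inputs` except on the batch dim, so that dim becomes dynamic.
additional_inputs = [torch.randn(4, 16).cuda()]

# Assumed factory: Lowerer.create(lower_setting=...) as used elsewhere in this module.
lowerer = Lowerer.create(lower_setting=LowerSetting(max_batch_size=32))
trt_module = lowerer(module, inputs, additional_inputs)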

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py

+28 -3

@@ -1,11 +1,13 @@
 from functools import partial, wraps
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Optional, Sequence

 import torch
 from torch import nn
 from torch.fx.passes.pass_manager import inplace_wrapper, PassManager
 from torch.fx.passes.shape_prop import ShapeProp
-from torch.fx.passes.splitter_base import SplitResult
+from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
+
+from ..input_tensor_spec import generate_input_specs

 from ..lower_setting import LowerSetting
 from ..observer import Observer

@@ -120,13 +122,33 @@ def _split_pass(self) -> PassManager:

     def _lower_pass(self) -> PassManager:
         def lower_func(split_result: SplitResult) -> nn.Module:
+            if (
+                hasattr(self.lower_setting, "explicit_batch_dimension")
+                and self.lower_setting.explicit_batch_dimension
+                and self._additional_input
+            ):
+                additional_submodule_inputs = generate_inputs_for_submodules(
+                    split_result.split_module,
+                    self._additional_input,
+                    list(split_result.submodule_inputs.keys()),
+                )
+            else:
+                additional_submodule_inputs = None
+
             for submod_name, submod_inputs in split_result.submodule_inputs.items():
                 submod = getattr(split_result.split_module, submod_name)

                 LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)

                 # Only acc submodules will be lowered.
                 if not submod_name.startswith(split_result.non_acc_submodule_prefix):
+                    self.lower_setting.input_specs = generate_input_specs(
+                        submod_inputs,
+                        self.lower_setting,
+                        additional_submodule_inputs[submod_name]
+                        if additional_submodule_inputs
+                        else None,
+                    )
                     lowered_module = self._lower_func(
                         submod, submod_inputs, self.lower_setting, submod_name
                     )

@@ -139,8 +161,11 @@ def lower_func(split_result: SplitResult) -> nn.Module:

         return PassManager.build_from_passlist([lower_func])

-    def build_lower_pipeline(self, input: Input) -> PassManager:
+    def build_lower_pipeline(
+        self, input: Input, additional_input: Optional[Input] = None
+    ) -> PassManager:
         self._input = input
+        self._additional_input = additional_input
         passes = []

         passes.append(self._const_fold_pass())
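
The per-submodule flow in lower_func reduces to: map the additional inputs onto each acc submodule, then derive that submodule's input specs. A self-contained sketch with plain dicts standing in for SplitResult.submodule_inputs and the generate_inputs_for_submodules output (the submodule name and the _Setting stub are illustrative, not part of this change):

import torch
from torch_tensorrt.fx import generate_input_specs

class _Setting:
    # Hypothetical stand-in for the LowerSetting fields read in this pass.
    explicit_batch_dimension = True
    max_batch_size = 32
    opt_profile_replica = 1
    input_specs = None

submodule_inputs = {"_run_on_acc_0": [torch.randn(2, 8)]}
additional_submodule_inputs = {"_run_on_acc_0": [torch.randn(3, 8)]}

setting = _Setting()
for name, sub_inputs in submodule_inputs.items():
    # Mirrors the assignment added to lower_func above.
    setting.input_specs = generate_input_specs(
        sub_inputs,
        setting,
        additional_submodule_inputs[name] if additional_submodule_inputs else None,
    )
    print(name, setting.input_specs[0].shape)  # _run_on_acc_0 (-1, 8)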

py/torch_tensorrt/fx/passes/pass_utils.py

+5 -2

@@ -41,10 +41,13 @@ def _validate_inference(pass_: PassFunc) -> PassFunc:

         @wraps(pass_)
         def pass_with_validation(
-            module: fx.GraphModule, input: Input
+            module: fx.GraphModule,
+            input: Input,
+            *args,
+            **kwargs,
         ) -> fx.GraphModule:
             res0 = module(*input)
-            processed_module = pass_(module, input)
+            processed_module = pass_(module, input, *args, **kwargs)
             res1 = processed_module(*input)

             tensor_res_0 = _collect_tensors(res0)
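
Why the wider signature matters: passes decorated for inference validation may now take extra positional or keyword arguments, and the wrapper forwards them instead of failing. A hedged sketch, assuming validate_inference(rtol, atol) is the decorator factory wrapping _validate_inference in this file:

import torch
from torch import fx
from torch_tensorrt.fx.passes.pass_utils import validate_inference

@validate_inference(rtol=1e-3, atol=1e-3)
def noop_pass(module: fx.GraphModule, input, verbose: bool = False) -> fx.GraphModule:
    # A do-nothing pass with an extra keyword argument.
    if verbose:
        print(f"pass ran on {len(module.graph.nodes)} nodes")
    return module

gm = fx.symbolic_trace(torch.nn.ReLU())
# The extra keyword is forwarded through *args/**kwargs to the pass.
gm = noop_pass(gm, [torch.randn(4)], verbose=True)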

py/torch_tensorrt/fx/test/converters/acc_op/test_prod.py

+21 -1

@@ -2,7 +2,7 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec

 # NOTE torch.prod will only accept one dim unlike other reduce ops which accept tuples

@@ -93,6 +93,26 @@ def forward(self, x):
             test_implicit_batch_dim=False,
         )

+    def test_prod_all_dims_with_dynamic_shape(
+        self,
+        op=torch.prod,
+    ):
+        class Prod(torch.nn.Module):
+            def forward(self, x):
+                return op(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            Prod(), input_specs, expected_ops={acc_ops.prod}
+        )
+

 if __name__ == "__main__":
     run_tests()
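
For context on the spec used in the new test: a shape of all -1 marks every dim dynamic, and each shape_ranges entry is a (min, opt, max) triple of concrete shapes the TensorRT optimization profile is built for. Restated outside the test harness:

import torch
from torch_tensorrt.fx import InputTensorSpec

# Dynamic on every dim; valid input shapes range from (1, 1, 1, 1) to (2, 3, 10, 10),
# with (2, 3, 4, 5) as the shape the engine is optimized for.
spec = InputTensorSpec(
    shape=(-1, -1, -1, -1),
    dtype=torch.float32,
    shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
)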

py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py

+42 -1

@@ -4,7 +4,7 @@

 import torch
 from torch.testing._internal.common_utils import run_tests, TestCase
-from torch_tensorrt.fx import InputTensorSpec
+from torch_tensorrt.fx import generate_input_specs, InputTensorSpec, LowerSetting


 class TestTRTModule(TestCase):

@@ -47,6 +47,47 @@ def test_from_tensors_with_dynamic_batch_size(self):
             self.assertEqual(batch_size, shape[0])
             self.assertSequenceEqual(tensor.shape[1:], shape[1:])

+    def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self):
+        tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)]
+        batch_size_range = [2, 3, 4]
+        specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            tensors, batch_size_range, batch_dims=[0, 1]
+        )
+        for i, spec_and_tensor in enumerate(zip(specs, tensors)):
+            spec, tensor = spec_and_tensor
+            self._validate_spec(spec, tensor, dynamic_dims=[i])
+
+            for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]):
+                self.assertEqual(batch_size, shape[i])
+                tensor_shape = list(tensor.shape)
+                tensor_shape[i] = batch_size
+                self.assertSequenceEqual(tensor_shape, shape)
+
+    def test_generate_input_specs(self):
+        lower_setting = LowerSetting(
+            explicit_batch_dimension=False, max_batch_size=256, opt_profile_replica=2
+        )
+
+        # Implicit batch dim.
+        inputs = [torch.randn(1, 2, 3)]
+        specs = generate_input_specs(inputs, lower_setting)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor)
+
+        # Explicit batch dim without additional inputs.
+        lower_setting.explicit_batch_dimension = True
+        specs = generate_input_specs(inputs, lower_setting)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor, dynamic_dims=[0])
+            self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica)
+
+        # Explicit batch dim with additional inputs.
+        additional_inputs = [torch.randn(1, 1, 3)]
+        specs = generate_input_specs(inputs, lower_setting, additional_inputs)
+        for spec, tensor in zip(specs, inputs):
+            self._validate_spec(spec, tensor, dynamic_dims=[1])
+            self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica)
+

 if __name__ == "__main__":
     run_tests()
