Commit 35c489d

fix: structured inputs for CudaGraphsTorchTensorRTModule (#3407)
1 parent 80a06ec commit 35c489d

File tree: 2 files changed (+105 -50 lines changed)


py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py (+55 -14)
```diff
@@ -1,16 +1,48 @@
 from __future__ import annotations
 
 import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from torch_tensorrt.dynamo import partitioning
 
 logger = logging.getLogger(__name__)
 
 
+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Process inputs using tree_unflatten and tree_map to reconstruct inputs
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def convert_input_to_cuda_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Apply the tensor creation to the reconstructed structure
+    processed_inputs = tree_map(convert_input_to_cuda_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]
+
+
 class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
 
@@ -43,14 +75,15 @@ def warm_up(self) -> None:
         Warm up is necessary to ensure that memory allocations and initializations
         are not recorded in cuda graphs
         """
+
         with torch_tensorrt.logging.errors():
             with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
                 s = torch.cuda.Stream()
                 s.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(s):
                     for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
                 torch.cuda.current_stream().wait_stream(s)
 
     def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -77,15 +110,18 @@ def __del__(self) -> None:
     def set_use_output_allocator(self, enable: bool) -> None:
         self.use_output_allocator_outputs = enable
 
-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        inputs, _ = tree_flatten((args, kwargs))
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
             need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
-                self._input_buffers = [None] * len(self.inputs)
+                self._input_buffers = [None] * len(inputs)
 
                 self.is_weight_streaming_set = False
             # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -98,10 +134,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 for i in inputs
             ]
             assert len(contiguous_inputs) == len(
-                self.inputs
-            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+                inputs
+            ), f"Wrong number of inputs, expect {len(inputs)} get {len(contiguous_inputs)}."
 
-            for i, _ in enumerate(self.inputs):
+            for i, _ in enumerate(inputs):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
                         f"Detected input[{i}] is not on a cuda device. "
@@ -116,8 +152,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     )
 
                 assert (
-                    contiguous_inputs[i].dtype == self.inputs[i].dtype
-                ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
 
                 if need_cudagraphs_record:
                     # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -126,6 +162,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 else:
                     self._input_buffers[i].copy_(contiguous_inputs[i])
 
+            if need_cudagraphs_record:
+                # Reconstruct the original args and kwargs structure from static input buffers
+                # using the input specification stored during module compilation
+                args, kwargs = tree_unflatten(
+                    self._input_buffers, self.compiled_module._in_spec
+                )
+
             self._caller_stream = torch.cuda.current_stream()
             if (
                 self._engine_stream == torch.cuda.default_stream()
@@ -139,9 +182,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if need_cudagraphs_record:
                 self.cudagraph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
-                    self._output_buffers = self.compiled_module(
-                        *self._input_buffers
-                    )
+                    self._output_buffers = self.compiled_module(*args, **kwargs)
 
                 self.cudagraph.replay()  # type: ignore
             self._caller_stream.wait_stream(self._engine_stream)
@@ -158,4 +199,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.cudagraph:
                 self.cudagraph.reset()
                 self.cudagraph = None
-            return self.compiled_module(*inputs)
+            return self.compiled_module(*args, **kwargs)
```
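This fix replaces the wrapper's flat positional-tensor contract with PyTorch's pytree utilities: `forward` flattens whatever nested `(args, kwargs)` structure the caller passes via `tree_flatten`, and `_unflatten_inputs` rebuilds that structure from the `_in_spec` recorded on the compiled module. A minimal sketch of the round trip these helpers rely on; the nested structure below is illustrative, not taken from the commit:

```python
# Minimal sketch of the pytree round trip used by this commit.
# Any mix of tuples, lists, and dicts of tensors flattens the same way.
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

args = (torch.randn(8, 100),)
kwargs = {"b": torch.randn(1, 128), "d": {"value": torch.randn(1)}}

# Flatten (args, kwargs) into a flat list of tensor leaves plus a
# spec that records the original nesting.
flat, spec = tree_flatten((args, kwargs))
print(len(flat))  # 3 leaves

# Rebuild the exact (args, kwargs) structure from the flat list.
# The runtime module does this with the spec stored on the compiled
# module as `_in_spec`.
new_args, new_kwargs = tree_unflatten(flat, spec)
assert torch.equal(new_kwargs["d"]["value"], kwargs["d"]["value"])
```

Because the spec is fixed at export time, the runtime can keep managing a flat list of static buffers internally and still invoke the compiled module with its original signature.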

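The reason `forward` stages inputs into persistent `_input_buffers` (and re-unflattens them before recording) is that CUDA graph replay reuses the exact device addresses captured at record time. A generic sketch of that record/replay pattern with a static buffer, independent of Torch-TensorRT and requiring a CUDA-capable GPU:

```python
# Generic CUDA graph record/replay with a static input buffer.
import torch

model = torch.nn.Linear(100, 64).cuda().eval()
static_input = torch.zeros(8, 100, device="cuda")  # persistent buffer

# Warm up on a side stream so one-time allocations and
# initializations are not captured in the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Record: the graph captures the kernels and the buffer addresses.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay: copy fresh data into the captured buffer, then replay.
static_input.copy_(torch.randn(8, 100, device="cuda"))
g.replay()
print(static_output.shape)  # torch.Size([8, 64])
```

Copying new data into the captured buffer with `copy_` before each replay is exactly what the `self._input_buffers[i].copy_(contiguous_inputs[i])` branch in `forward` does.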
tests/py/dynamo/runtime/test_004_weight_streaming.py (+50 -36)
```diff
@@ -6,6 +6,7 @@
 import torch_tensorrt as torchtrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
+from torch_tensorrt.dynamo.utils import prepare_inputs
 
 INPUT_SIZE = (64, 100)
 
@@ -290,10 +291,6 @@ def test_weight_streaming_cudagraphs(self, _, use_python_runtime):
             ("cpp_runtime", False),
         ]
     )
-    @unittest.skipIf(
-        os.environ.get("CI_BUILD") == "1",
-        "Skipping test due to CI resource constraints",
-    )
     def test_runtime_state_change(self, _, use_python_runtime):
         class SampleModel(torch.nn.Module):
             def __init__(self):
@@ -302,45 +299,62 @@ def __init__(self):
                 self.layer2 = torch.nn.Linear(128, 64)
                 self.relu = torch.nn.ReLU()
 
-            def forward(self, x):
+            def forward(self, x, b=None, c=None, d=None, e=[]):
                 out = self.layer1(x)
+                out = out + b
+                if c is not None:
+                    out = out * c
                 out = self.relu((out + 2.0) * 0.05)
+                if d is not None:
+                    out = out - d["value"] + d["value2"]
                 out = self.layer2(out)
+                for n in e:
+                    out += n
                 return out
 
-        inputs = torchtrt.Input(
-            min_shape=(1, 100),
-            opt_shape=(64, 100),
-            max_shape=(128, 100),
-            dtype=torch.float,
-            name="x",
-        )
         model = SampleModel().eval().cuda()
         input_list = []
-        input_list.append(torch.randn((8, 100)).cuda())
-        input_list.append(torch.randn((12, 100)).cuda())
-        input_list.append(torch.randn((12, 100)).cuda())
-        input_list.append(torch.randn((8, 100)).cuda())
-        input_list.append(torch.randn((8, 100)).cuda())
-
-        dynamic_shapes = (
-            {
-                0: torch.export.Dim("batch_size", min=1, max=128),
-            },
-        )
-        exp_program = torch.export.export(
-            model, (input_list[0],), dynamic_shapes=dynamic_shapes
-        )
-
+        for batch_size in [8, 12, 12, 8, 8]:
+            args = [torch.rand((batch_size, 100)).to("cuda")]
+            kwargs = {
+                "b": torch.rand((1, 128)).to("cuda"),
+                "d": {
+                    "value": torch.rand(1).to("cuda"),
+                    "value2": torch.tensor(1.2).to("cuda"),
+                },
+                "e": [torch.rand(1).to("cuda"), torch.rand(1).to("cuda")],
+            }
+            input_list.append((args, kwargs))
+
+        kwarg_torchtrt_input = prepare_inputs(input_list[0][1])
+
+        compile_spec = {
+            "arg_inputs": [
+                torchtrt.Input(
+                    min_shape=(1, 100),
+                    opt_shape=(64, 100),
+                    max_shape=(128, 100),
+                    dtype=torch.float32,
+                    name="x",
+                ),
+            ],
+            "kwarg_inputs": kwarg_torchtrt_input,
+            "device": torchtrt.Device("cuda:0"),
+            "enabled_precisions": {torch.float},
+            "pass_through_build_failures": True,
+            "min_block_size": 1,
+            "ir": "dynamo",
+            "cache_built_engines": False,
+            "reuse_cached_engines": False,
+            "use_explicit_typing": True,
+            "enable_weight_streaming": True,
+            "torch_executed_ops": {"torch.ops.aten.mul.Tensor"},
+            "use_python_runtime": use_python_runtime,
+        }
+        exp_program = torchtrt.dynamo.trace(model, **compile_spec)
         optimized_model = torchtrt.dynamo.compile(
             exp_program,
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-            use_explicit_typing=True,
-            enable_weight_streaming=True,
-            torch_executed_ops={"torch.ops.aten.mul.Tensor"},
-            use_python_runtime=use_python_runtime,
+            **compile_spec,
         )
 
         # List of tuples representing different configurations for three features:
@@ -361,12 +375,12 @@ def test_trt_model(enable_weight_streaming, optimized_model, input_list):
             for i in range(len(input_list)):
                 if enable_weight_streaming and i == 4:
                     weight_streaming_ctx.device_budget = int(streamable_budget * 0.6)
-                out_list.append(optimized_model(input_list[i]))
+                out_list.append(optimized_model(*input_list[i][0], **input_list[i][1]))
             return out_list
 
         ref_out_list = []
         for i in range(len(input_list)):
-            ref_out_list.append(model(input_list[i]))
+            ref_out_list.append(model(*input_list[i][0], **input_list[i][1]))
 
         pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs(
             optimized_model
```
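The reworked test drives both the eager reference model and the compiled module through the same nested `(args, kwargs)` structure, which is what the runtime change enables. A small stand-alone sketch of that calling convention; `StandIn` is an illustrative stand-in for the compiled module, not part of the commit:

```python
# Sketch of the structured calling convention the new test exercises.
import torch

class StandIn(torch.nn.Module):
    """Illustrative stand-in; shapes mirror SampleModel in the test."""

    def forward(self, x, b=None, d=None, e=()):
        out = x.mean() + b.mean()
        if d is not None:
            out = out - d["value"] + d["value2"]
        for n in e:
            out = out + n
        return out

module = StandIn()
args = [torch.rand(8, 100)]
kwargs = {
    "b": torch.rand(1, 128),
    "d": {"value": torch.rand(1), "value2": torch.tensor(1.2)},
    "e": [torch.rand(1), torch.rand(1)],
}

# Before this commit the CUDA-graphs wrapper accepted only a flat
# positional tensor list; with the fix, nested kwargs like these pass
# straight through, as the test does with
# optimized_model(*input_list[i][0], **input_list[i][1]).
out = module(*args, **kwargs)
print(out.shape)  # torch.Size([1])
```

In the test itself, `prepare_inputs` converts the example kwargs (including the nested dict and list) into the `torch_tensorrt.Input` specs passed to `torchtrt.dynamo.compile` as `kwarg_inputs`.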
