Skip to content

Commit ebaacff

Browse files
committed
fix: structured inputs for CudaGraphsTorchTensorRTModule
1 parent 0a46392 commit ebaacff

File tree

1 file changed

+61
-18
lines changed

1 file changed

+61
-18
lines changed

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

+61-18
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,48 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import List, Optional, Sequence, Tuple
4+
from typing import Any, List, Optional, Sequence, Tuple
55

66
import torch
77
import torch_tensorrt
88
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
9+
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
910
from torch_tensorrt.dynamo import partitioning
1011

1112
logger = logging.getLogger(__name__)
1213

1314

15+
def _unflatten_inputs(
16+
flattened_inputs: Sequence[torch_tensorrt.Input],
17+
compiled_module: torch.fx.GraphModule,
18+
) -> Tuple[Any, Any]:
19+
"""
20+
Reconstruct the (args, kwargs) input structure with tree_unflatten, then use tree_map to turn each torch_tensorrt.Input into an example CUDA tensor
21+
22+
Args:
23+
flattened_inputs: Flattened torch_tensorrt.Input specs to process
24+
compiled_module: The compiled GraphModule containing input specifications
25+
26+
Returns:
27+
Tuple of (args, kwargs) containing reconstructed input tensors
28+
"""
29+
30+
def create_example_tensor(input: Any) -> torch.Tensor:
31+
if isinstance(input, torch_tensorrt.Input):
32+
return input.torch_tensor.cuda()
33+
else:
34+
raise RuntimeError("Input is not a torch_tensorrt.Input")
35+
36+
# Reconstruct the (args, kwargs) structure that was flattened during export
37+
pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
38+
# Apply the tensor creation to the reconstructed structure
39+
processed_inputs = tree_map(create_example_tensor, pytree_inputs)
40+
41+
# Since inputs were originally flattened from (args, kwargs),
42+
# processed_inputs is now that same tuple structure
43+
return processed_inputs[0], processed_inputs[1]
44+
45+
1446
class CudaGraphsTorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
1547
"""This Wrapper runtime module is to record/replay whole cuda graph in sub modules
1648
@@ -42,14 +74,15 @@ def warm_up(self) -> None:
4274
Warm up is necessary to ensure that memory allocations and initializations
4375
are not recorded in cuda graphs
4476
"""
77+
4578
with torch_tensorrt.logging.errors():
4679
with unset_fake_temporarily():
47-
inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
80+
args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
4881
s = torch.cuda.Stream()
4982
s.wait_stream(torch.cuda.current_stream())
5083
with torch.cuda.stream(s):
5184
for _ in range(3):
52-
self.compiled_module(*inputs_tensor)
85+
self.compiled_module(*args, **kwargs)
5386
torch.cuda.current_stream().wait_stream(s)
5487

5588
def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -73,7 +106,12 @@ def __del__(self) -> None:
73106
if self.cudagraph:
74107
self.cudagraph.reset()
75108

76-
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
109+
def forward(
110+
self, *args: Any, **kwargs: Any
111+
) -> torch.Tensor | Tuple[torch.Tensor, ...]:
112+
# pytree_inputs = tree_unflatten(self.inputs, self.compiled_module._in_spec)
113+
114+
inputs, spec = tree_flatten((args, kwargs))
77115
cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
78116
if cudagraphs_enabled:
79117
shape_changed = self.validate_input_shapes(inputs)
@@ -94,10 +132,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
94132
for i in inputs
95133
]
96134
assert len(contiguous_inputs) == len(
97-
self.inputs
98-
), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
135+
inputs
136+
), f"Wrong number of inputs, expect {len(inputs)} get {len(contiguous_inputs)}."
99137

100-
for i, _ in enumerate(self.inputs):
138+
for i, _ in enumerate(inputs):
101139
if not contiguous_inputs[i].is_cuda:
102140
logger.warning(
103141
f"Detected input[{i}] is not on a cuda device. "
@@ -112,15 +150,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
112150
)
113151

114152
assert (
115-
contiguous_inputs[i].dtype == self.inputs[i].dtype
116-
), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
153+
contiguous_inputs[i].dtype == inputs[i].dtype
154+
), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
155+
156+
if need_cudagraphs_record:
157+
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
158+
# Clone is required to avoid re-using user-provided GPU memory
159+
self._input_buffers[i] = contiguous_inputs[i].clone()
160+
else:
161+
self._input_buffers[i].copy_(contiguous_inputs[i])
117162

118163
if need_cudagraphs_record:
119-
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
120-
# Clone is required to avoid re-using user-provided GPU memory
121-
self._input_buffers[i] = contiguous_inputs[i].clone()
122-
else:
123-
self._input_buffers[i].copy_(contiguous_inputs[i])
164+
# Reconstruct the original args and kwargs structure from static input buffers
165+
# using the input specification stored during module compilation
166+
args, kwargs = tree_unflatten(
167+
self._input_buffers, self.compiled_module._in_spec
168+
)
124169

125170
self._caller_stream = torch.cuda.current_stream()
126171
if (
@@ -135,9 +180,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
135180
if need_cudagraphs_record:
136181
self.cudagraph = torch.cuda.CUDAGraph()
137182
with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
138-
self._output_buffers = self.compiled_module(
139-
*self._input_buffers
140-
)
183+
self._output_buffers = self.compiled_module(*args, **kwargs)
141184

142185
self.cudagraph.replay() # type: ignore
143186
self._caller_stream.wait_stream(self._engine_stream)
@@ -154,4 +197,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
154197
if self.cudagraph:
155198
self.cudagraph.reset()
156199
self.cudagraph = None
157-
return self.compiled_module(*inputs)
200+
return self.compiled_module(*args, **kwargs)

0 commit comments

Comments
 (0)