 from __future__ import annotations
 
 import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from torch_tensorrt.dynamo import partitioning
 
 logger = logging.getLogger(__name__)
 
 
+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Reconstruct the original (args, kwargs) inputs using tree_unflatten and tree_map.
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def convert_input_to_cuda_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Apply the CUDA tensor conversion to each leaf of the reconstructed structure
+    processed_inputs = tree_map(convert_input_to_cuda_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]
+
+
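For intuition, a minimal standalone sketch of the pytree round trip this helper relies on (toy leaf values; a locally built spec stands in for the module's _in_spec):

    from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

    args, kwargs = (1, 2), {"scale": 3}
    leaves, spec = tree_flatten((args, kwargs))   # leaves == [1, 2, 3]
    rebuilt = tree_unflatten(leaves, spec)        # ((1, 2), {"scale": 3})
    doubled = tree_map(lambda x: x * 2, rebuilt)  # applied leaf-wise
    assert doubled == ((2, 4), {"scale": 6})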
 class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
 
@@ -43,14 +75,15 @@ def warm_up(self) -> None:
         Warm up is necessary to ensure that memory allocations and initializations
         are not recorded in cuda graphs
         """
+
         with torch_tensorrt.logging.errors():
             with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
                 s = torch.cuda.Stream()
                 s.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(s):
                     for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
                 torch.cuda.current_stream().wait_stream(s)
 
     def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
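The side-stream warm-up above follows the standard CUDA Graphs recipe: a few eager iterations on a non-default stream trigger lazy allocations and initializations before anything is captured. A minimal standalone sketch (hypothetical model and static_input):

    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):       # lazy allocations/initializations happen here,
            model(static_input)  # safely outside of graph capture
    torch.cuda.current_stream().wait_stream(s)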
@@ -77,15 +110,18 @@ def __del__(self) -> None:
     def set_use_output_allocator(self, enable: bool) -> None:
         self.use_output_allocator_outputs = enable
 
-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        inputs, _ = tree_flatten((args, kwargs))
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
             need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
-                self._input_buffers = [None] * len(self.inputs)
+                self._input_buffers = [None] * len(inputs)
 
             self.is_weight_streaming_set = False
             # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -98,10 +134,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 for i in inputs
             ]
             assert len(contiguous_inputs) == len(
-                self.inputs
-            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+                inputs
+            ), f"Wrong number of inputs, expected {len(inputs)}, got {len(contiguous_inputs)}."
 
-            for i, _ in enumerate(self.inputs):
+            for i, _ in enumerate(inputs):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
                         f"Detected input[{i}] is not on a cuda device. "
@@ -116,8 +152,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     )
 
                 assert (
-                    contiguous_inputs[i].dtype == self.inputs[i].dtype
-                ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expected {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
 
                 if need_cudagraphs_record:
                     # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -126,6 +162,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 else:
                     self._input_buffers[i].copy_(contiguous_inputs[i])
 
+            if need_cudagraphs_record:
+                # Reconstruct the original args and kwargs structure from the static
+                # input buffers, using the input spec stored during module compilation
+                args, kwargs = tree_unflatten(
+                    self._input_buffers, self.compiled_module._in_spec
+                )
+
             self._caller_stream = torch.cuda.current_stream()
             if (
                 self._engine_stream == torch.cuda.default_stream()
@@ -139,9 +182,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if need_cudagraphs_record:
                     self.cudagraph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
-                        self._output_buffers = self.compiled_module(
-                            *self._input_buffers
-                        )
+                        self._output_buffers = self.compiled_module(*args, **kwargs)
 
                 self.cudagraph.replay()  # type: ignore
             self._caller_stream.wait_stream(self._engine_stream)
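The two hunks above implement the record-once / replay-with-fresh-data pattern: inputs are copied into fixed buffers, the graph is captured reading from those addresses, and each later call only copies and replays. A minimal standalone sketch (hypothetical model, static_input, new_input):

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = model(static_input)  # capture one iteration

    static_input.copy_(new_input)   # update the addresses the graph reads from
    g.replay()                      # rerun the captured kernels
    result = static_output.clone()  # copy out before the next replay overwrites it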
@@ -158,4 +199,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.cudagraph:
                 self.cudagraph.reset()
             self.cudagraph = None
-        return self.compiled_module(*inputs)
+        return self.compiled_module(*args, **kwargs)
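Net effect: the wrapper now accepts the same (args, kwargs) structure the module was exported with, rather than positional tensors only. A hedged usage sketch (assumes wrapped is an already constructed CudaGraphsTorchTensorRTModule and x, m are CUDA tensors):

    out = wrapped(x, mask=m)  # kwargs flow through forward, which recovers the
                              # flat tensor list via tree_flatten((args, kwargs))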