 from __future__ import annotations

 import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple

 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from torch_tensorrt.dynamo import partitioning

 logger = logging.getLogger(__name__)


+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Process inputs using tree_unflatten and tree_map to reconstruct the inputs
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def create_example_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Apply the tensor creation to the reconstructed structure
+    processed_inputs = tree_map(create_example_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]
+
+
 class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
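For context, a minimal sketch of the pytree round-trip that `_unflatten_inputs` relies on (the tensors and structure below are illustrative, not from this module): `tree_flatten` turns an `(args, kwargs)` pair into a flat list of leaves plus a `TreeSpec`, `tree_unflatten` rebuilds the original nesting from that spec, and `tree_map` transforms every leaf while preserving structure.

    import torch
    from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

    args = (torch.ones(2, 2),)
    kwargs = {"mask": torch.zeros(2, 2, dtype=torch.bool)}

    # Flatten the (args, kwargs) pair into leaves plus a structure spec
    leaves, spec = tree_flatten((args, kwargs))

    # The same spec restores the original nesting, leaf order preserved
    rebuilt_args, rebuilt_kwargs = tree_unflatten(leaves, spec)
    assert torch.equal(rebuilt_args[0], args[0])

    # tree_map applies a function leaf-wise, keeping the structure intact
    contiguous = tree_map(lambda t: t.contiguous(), (args, kwargs))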
@@ -42,14 +74,15 @@ def warm_up(self) -> None:
         Warm up is necessary to ensure that memory allocations and initializations
         are not recorded in cuda graphs
         """
+
         with torch_tensorrt.logging.errors():
             with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
                 s = torch.cuda.Stream()
                 s.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(s):
                     for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
                 torch.cuda.current_stream().wait_stream(s)

     def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
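The side-stream warm-up above follows the pattern from PyTorch's CUDA graphs documentation: the first iterations trigger one-time allocations and lazy initialization, which must happen outside capture because a graph replays fixed kernels against fixed addresses. A standalone sketch of that pattern, with `model` and `static_input` as placeholder names:

    import torch

    model = torch.nn.Linear(16, 16).cuda()
    static_input = torch.randn(8, 16, device="cuda")

    # Warm up on a side stream so allocations/lazy init are not captured
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture once; replay re-runs the recorded kernels on the current
    # contents of static_input
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = model(static_input)
    g.replay()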
@@ -73,7 +106,12 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()

-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        inputs, _ = tree_flatten((args, kwargs))
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
@@ -94,10 +132,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 for i in inputs
             ]
             assert len(contiguous_inputs) == len(
-                self.inputs
-            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+                inputs
+            ), f"Wrong number of inputs, expected {len(inputs)}, got {len(contiguous_inputs)}."

-            for i, _ in enumerate(self.inputs):
+            for i, _ in enumerate(inputs):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
                         f"Detected input[{i}] is not on a cuda device. "
@@ -112,15 +150,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     )

                 assert (
-                    contiguous_inputs[i].dtype == self.inputs[i].dtype
-                ), f"Dtype mismatch for {i}th input. Expected {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expected {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+
+                if need_cudagraphs_record:
+                    # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
+                    # Clone is required to avoid re-using user-provided GPU memory
+                    self._input_buffers[i] = contiguous_inputs[i].clone()
+                else:
+                    self._input_buffers[i].copy_(contiguous_inputs[i])

             if need_cudagraphs_record:
-                # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
-                # Clone is required to avoid re-using user-provided GPU memory
-                self._input_buffers[i] = contiguous_inputs[i].clone()
-            else:
-                self._input_buffers[i].copy_(contiguous_inputs[i])
+                # Reconstruct the original args and kwargs structure from static input buffers
+                # using the input specification stored during module compilation
+                args, kwargs = tree_unflatten(
+                    self._input_buffers, self.compiled_module._in_spec
+                )

             self._caller_stream = torch.cuda.current_stream()
             if (
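The clone-versus-copy_ split above is the standard static-buffer discipline for graph replay: on the recording pass each input is cloned into a buffer the graph permanently reads from, and on later calls new values are copied in place into those same addresses, since rebinding the buffer variable would leave the captured graph reading stale memory. A minimal sketch of the idea (names illustrative):

    import torch

    user_input = torch.randn(4, device="cuda")

    # Record pass: clone so the graph owns its input memory rather than
    # aliasing a caller tensor that may be freed or reused
    input_buffer = user_input.clone()

    # Brief warm-up outside capture (see warm_up above)
    for _ in range(3):
        input_buffer * 2
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        output_buffer = input_buffer * 2

    # Replay pass: copy_ writes into the SAME addresses the graph captured;
    # reassigning input_buffer would silently break replay
    input_buffer.copy_(torch.randn(4, device="cuda"))
    g.replay()  # output_buffer now reflects the new input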
@@ -135,9 +180,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if need_cudagraphs_record:
                     self.cudagraph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
-                        self._output_buffers = self.compiled_module(
-                            *self._input_buffers
-                        )
+                        self._output_buffers = self.compiled_module(*args, **kwargs)

                 self.cudagraph.replay()  # type: ignore
                 self._caller_stream.wait_stream(self._engine_stream)
@@ -154,4 +197,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.cudagraph:
                 self.cudagraph.reset()
                 self.cudagraph = None
-            return self.compiled_module(*inputs)
+            return self.compiled_module(*args, **kwargs)
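End to end, this wrapper is engaged through the cudagraphs runtime mode rather than instantiated directly. A hedged usage sketch, assuming a recent Torch-TensorRT build where `torch_tensorrt.runtime.enable_cudagraphs` accepts the compiled module, and with `MyModel` as a hypothetical model containing graph breaks:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # hypothetical user model
    inputs = [torch.randn(1, 3, 224, 224, device="cuda")]
    trt_module = torch_tensorrt.compile(model, inputs=inputs)

    # When the compiled module contains graph breaks, whole-module
    # cudagraphs mode wraps it in CudaGraphsTorchTensorRTModule
    with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cg_module:
        out = cg_module(*inputs)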