@@ -92,8 +92,8 @@ def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
 
 
 class TestPrimBroadcastFusion(TestCase):
-    def test_input_as_output(self):
-        class InputAsOutput(torch.nn.Module):
+    def test_broadcast_fusion(self):
+        class BroadcastFusion(torch.nn.Module):
             def forward(self, x):
                 return torch.var_mean(x, keepdim=True)[1]
 
@@ -104,7 +104,7 @@ def forward(self, x):
             ).cuda(),
         ]
 
-        fx_graph = torch.fx.symbolic_trace(InputAsOutput())
+        fx_graph = torch.fx.symbolic_trace(BroadcastFusion())
         expected_ops = {torch.ops.aten.sum.dim_IntList}
         unexpected_ops = {torch.ops.aten.var.default, torch.ops.prims.var.default}
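+        # Lowering decomposes var_mean into sum-based reductions, so
+        # aten.sum.dim_IntList should appear while no var op remains.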
 
@@ -151,7 +151,118 @@ def forward(self, x):
             max_diff,
             0,
             DECIMALS_OF_AGREEMENT,
-            msg=f"InputAsOutput TRT outputs don't match with the original model.",
+            msg=f"BroadcastFusion TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
+class TestLowerEfficientAttention(TestCase):
+    def test_lower_efficient_attention(self):
+        class EfficientAttention(torch.nn.Module):
+            def forward(self, q, k, v):
+                attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+                    q, k, v, None, False
+                )
+                return attn[0]
+
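+        # q has shape (8, 4, 5, 4) while k and v use a shorter sequence length
+        # of 2, exercising attention with mismatched query/key lengths.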
+        inputs = [
+            torch.rand(8, 4, 5, 4).cuda(),
+            torch.rand(8, 4, 2, 4).cuda(),
+            torch.rand(8, 4, 2, 4).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(EfficientAttention())
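+        # Lowering should replace the fused aten op with a call to
+        # torch.nn.functional.scaled_dot_product_attention; the original
+        # aten._scaled_dot_product_efficient_attention must no longer appear.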
+        expected_ops = {torch.nn.functional.scaled_dot_product_attention}
+        unexpected_ops = {
+            torch.ops.aten._scaled_dot_product_efficient_attention.default
+        }
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
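+        # min_block_size=1 makes even single-op subgraphs eligible for TRT
+        # conversion, and pass_through_build_failures raises conversion errors
+        # rather than silently falling back to PyTorch.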
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
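+        # Compare the two backends via the maximum absolute element-wise
+        # difference between their outputs.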
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"EfficientAttention TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_efficient_attention_converter(self):
+        class EfficientAttention(torch.nn.Module):
+            def forward(self, q, k, v):
+                attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+                    q, k, v, None, False
+                )
+                return attn[0]
+
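+        # Unlike test_lower_efficient_attention, this test skips the op-coverage
+        # checks and validates only the numerical accuracy of the converter.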
+        inputs = [
+            torch.rand(1, 3, 6, 4).cuda(),
+            torch.rand(1, 3, 2, 4).cuda(),
+            torch.rand(1, 3, 2, 4).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(EfficientAttention())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"EfficientAttention TRT outputs don't match with the original model.",
         )
         torch._dynamo.reset()
 