@@ -313,6 +313,185 @@ def forward(self, x):
313
313
f"Var TRT outputs don't match with the original model." ,
314
314
)
315
315
316
+ def test_lowering_maxpool1d_functional (self ):
317
+ class MaxPool1d (torch .nn .Module ):
318
+ def forward (self , x ):
319
+ y = torch .nn .functional .max_pool1d (x , 3 )
320
+ return y
321
+
322
+ # Operations expected to be removed in the traced graph after decompositions
323
+ expected_ops = {torch .ops .aten .max_pool2d .default }
324
+ unexpected_ops = {
325
+ torch .ops .aten .max_pool1d_with_indices .default ,
326
+ torch .ops .aten .max_pool2d_with_indices .default ,
327
+ }
328
+
329
+ inputs = [torch .randn (4 , 8 , 27 ).cuda ()]
330
+
331
+ fx_graph = torch .fx .symbolic_trace (MaxPool1d ())
332
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
333
+ fx_graph ,
334
+ inputs ,
335
+ expected_ops = expected_ops ,
336
+ unexpected_ops = unexpected_ops ,
337
+ min_block_size = 1 ,
338
+ )
339
+
340
+ self .assertEquals (
341
+ len (unexpected_ops_seen ),
342
+ 0 ,
343
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
344
+ )
345
+
346
+ self .assertEquals (
347
+ len (expected_ops_unseen ),
348
+ 0 ,
349
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
350
+ )
351
+
352
+ torch ._dynamo .reset ()
353
+
354
+ # Validate that the results between Torch and Torch-TRT are similar
355
+ optimized_model = torch_tensorrt .compile (
356
+ fx_graph ,
357
+ "torch_compile" ,
358
+ inputs ,
359
+ min_block_size = 1 ,
360
+ pass_through_build_failures = True ,
361
+ )
362
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
363
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
364
+
365
+ max_diff = float (
366
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
367
+ )
368
+ self .assertAlmostEqual (
369
+ max_diff ,
370
+ 0 ,
371
+ DECIMALS_OF_AGREEMENT ,
372
+ f"MaxPool1d TRT outputs don't match with the original model." ,
373
+ )
374
+
375
+ def test_lowering_maxpool_2d_module (self ):
376
+ class MaxPool2d (torch .nn .Module ):
377
+ def __init__ (self , * args , ** kwargs ) -> None :
378
+ super ().__init__ (* args , ** kwargs )
379
+ self .maxpool = torch .nn .MaxPool2d ((5 , 3 ), stride = (2 , 1 ))
380
+
381
+ def forward (self , x ):
382
+ y = self .maxpool (x )
383
+ return y
384
+
385
+ # Operations expected to be removed in the traced graph after decompositions
386
+ expected_ops = {torch .ops .aten .max_pool2d .default }
387
+ unexpected_ops = {torch .ops .aten .max_pool2d_with_indices .default }
388
+
389
+ inputs = [torch .randn (1 , 3 , 25 , 30 ).cuda ()]
390
+
391
+ fx_graph = torch .fx .symbolic_trace (MaxPool2d ())
392
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
393
+ fx_graph ,
394
+ inputs ,
395
+ expected_ops = expected_ops ,
396
+ unexpected_ops = unexpected_ops ,
397
+ min_block_size = 1 ,
398
+ )
399
+
400
+ self .assertEquals (
401
+ len (unexpected_ops_seen ),
402
+ 0 ,
403
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
404
+ )
405
+
406
+ self .assertEquals (
407
+ len (expected_ops_unseen ),
408
+ 0 ,
409
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
410
+ )
411
+
412
+ torch ._dynamo .reset ()
413
+
414
+ # Validate that the results between Torch and Torch-TRT are similar
415
+ optimized_model = torch_tensorrt .compile (
416
+ fx_graph ,
417
+ "torch_compile" ,
418
+ inputs ,
419
+ min_block_size = 1 ,
420
+ pass_through_build_failures = True ,
421
+ )
422
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
423
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
424
+
425
+ max_diff = float (
426
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
427
+ )
428
+ self .assertAlmostEqual (
429
+ max_diff ,
430
+ 0 ,
431
+ DECIMALS_OF_AGREEMENT ,
432
+ f"MaxPool2d TRT outputs don't match with the original model." ,
433
+ )
434
+
435
+ def test_lowering_maxpool_3d_module (self ):
436
+ class MaxPool3d (torch .nn .Module ):
437
+ def __init__ (self , * args , ** kwargs ) -> None :
438
+ super ().__init__ (* args , ** kwargs )
439
+ self .maxpool = torch .nn .MaxPool3d (3 )
440
+
441
+ def forward (self , x ):
442
+ y = self .maxpool (x )
443
+ return y
444
+
445
+ # Operations expected to be removed in the traced graph after decompositions
446
+ expected_ops = {torch .ops .aten .max_pool3d .default }
447
+ unexpected_ops = {torch .ops .aten .max_pool3d_with_indices .default }
448
+
449
+ inputs = [torch .randn (4 , 8 , 27 , 72 , 96 ).cuda ()]
450
+
451
+ fx_graph = torch .fx .symbolic_trace (MaxPool3d ())
452
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
453
+ fx_graph ,
454
+ inputs ,
455
+ expected_ops = expected_ops ,
456
+ unexpected_ops = unexpected_ops ,
457
+ min_block_size = 1 ,
458
+ )
459
+
460
+ self .assertEquals (
461
+ len (unexpected_ops_seen ),
462
+ 0 ,
463
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
464
+ )
465
+
466
+ self .assertEquals (
467
+ len (expected_ops_unseen ),
468
+ 0 ,
469
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
470
+ )
471
+
472
+ torch ._dynamo .reset ()
473
+
474
+ # Validate that the results between Torch and Torch-TRT are similar
475
+ optimized_model = torch_tensorrt .compile (
476
+ fx_graph ,
477
+ "torch_compile" ,
478
+ inputs ,
479
+ min_block_size = 1 ,
480
+ pass_through_build_failures = True ,
481
+ )
482
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
483
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
484
+
485
+ max_diff = float (
486
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
487
+ )
488
+ self .assertAlmostEqual (
489
+ max_diff ,
490
+ 0 ,
491
+ DECIMALS_OF_AGREEMENT ,
492
+ f"MaxPool3d TRT outputs don't match with the original model." ,
493
+ )
494
+
316
495
317
496
if __name__ == "__main__" :
318
497
run_tests ()
0 commit comments