@@ -437,3 +437,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
437
437
ASSERT_TRUE (
438
438
torch_tensorrt::tests::util::almostEqual (jit_pre_results[0 ].toTensor (), jit_post_results[0 ].toTensor (), 2e-6 ));
439
439
}
440
+
441
+ TEST (LoweringPasses, RemoveAtenIntTensorValuesAgree) {
442
+ std::string source_graph_no_inputs = R"IR(
443
+ graph():
444
+ %0: int = prim::Constant[value=2]()
445
+ %11: int = prim::Constant[value=7]()
446
+ %3: Tensor = prim::NumToTensor(%0)
447
+ %1: Tensor = prim::NumToTensor(%11)
448
+ %4: Tensor = aten::floor_divide(%1, %3)
449
+ %7: Tensor = aten::mul(%3, %4)
450
+ %8: Tensor = aten::mul(%7, %1)
451
+ %50: int = aten::Int(%8)
452
+ %5: Tensor = prim::NumToTensor(%50)
453
+ return (%5))IR" ;
454
+ std::string target_graph_no_inputs = R"IR(
455
+ graph():
456
+ %0: int = prim::Constant[value=2]()
457
+ %1: int = prim::Constant[value=7]()
458
+ %4: int = aten::floordiv(%1, %0)
459
+ %7: int = aten::mul(%0, %4)
460
+ %40: int = aten::mul(%7, %1)
461
+ %4: Tensor = prim::NumToTensor(%40)
462
+ return (%4))IR" ;
463
+
464
+ auto g_in = std::make_shared<torch::jit::Graph>();
465
+ auto g_out = std::make_shared<torch::jit::Graph>();
466
+
467
+ torch::jit::parseIR (source_graph_no_inputs, g_in.get ());
468
+ torch::jit::parseIR (target_graph_no_inputs, g_out.get ());
469
+
470
+ auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g_in, {});
471
+ auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g_out, {});
472
+
473
+ ASSERT_TRUE (
474
+ torch_tensorrt::tests::util::almostEqual (jit_pre_results[0 ].toTensor (), jit_post_results[0 ].toTensor (), 2e-6 ));
475
+
476
+ // Ensure the lowering pass transforms the first graph into the second
477
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
478
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
479
+ auto sg = std::make_shared<torch::jit::Graph>();
480
+ torch::jit::parseIR (source_graph_no_inputs, sg.get ());
481
+
482
+ torch_tensorrt::core::lowering::passes::ReplaceAtenInt (sg);
483
+
484
+ auto tg = std::make_shared<torch::jit::Graph>();
485
+ torch::jit::parseIR (target_graph_no_inputs, tg.get ());
486
+
487
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
488
+ }
489
+
490
+ TEST (LoweringPasses, RemoveAtenIntSizeTensorValuesAgree) {
491
+ std::string source_graph_no_inputs = R"IR(
492
+ graph(%x.0: Tensor):
493
+ %10: int = prim::Constant[value=0]()
494
+ %100: int = aten::size(%x.0, %10)
495
+ %0: Tensor = prim::NumToTensor(%100)
496
+ %11: int = prim::Constant[value=9]()
497
+ %1: Tensor = prim::NumToTensor(%11)
498
+ %4: Tensor = aten::floor_divide(%1, %0)
499
+ %7: Tensor = aten::mul(%0, %4)
500
+ %8: Tensor = aten::mul(%7, %1)
501
+ %50: int = aten::Int(%8)
502
+ %5: Tensor = prim::NumToTensor(%50)
503
+ return (%5))IR" ;
504
+ std::string target_graph_no_inputs = R"IR(
505
+ graph(%x.0: Tensor):
506
+ %10: int = prim::Constant[value=0]()
507
+ %0: int = aten::size(%x.0, %10)
508
+ %1: int = prim::Constant[value=9]()
509
+ %4: int = aten::floordiv(%1, %0)
510
+ %7: int = aten::mul(%0, %4)
511
+ %40: int = aten::mul(%7, %1)
512
+ %4: Tensor = prim::NumToTensor(%40)
513
+ return (%4))IR" ;
514
+
515
+ auto g_in = std::make_shared<torch::jit::Graph>();
516
+ auto g_out = std::make_shared<torch::jit::Graph>();
517
+
518
+ auto in_0 = at::rand ({2 , 3 , 5 , 5 }, {at::kCUDA });
519
+
520
+ torch::jit::parseIR (source_graph_no_inputs, g_in.get ());
521
+ torch::jit::parseIR (target_graph_no_inputs, g_out.get ());
522
+
523
+ auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g_in, {in_0});
524
+ auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT (g_out, {in_0});
525
+
526
+ ASSERT_TRUE (
527
+ torch_tensorrt::tests::util::almostEqual (jit_pre_results[0 ].toTensor (), jit_post_results[0 ].toTensor (), 2e-6 ));
528
+
529
+ // Ensure the lowering pass transforms the first graph into the second
530
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
531
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
532
+ auto sg = std::make_shared<torch::jit::Graph>();
533
+ torch::jit::parseIR (source_graph_no_inputs, sg.get ());
534
+
535
+ torch_tensorrt::core::lowering::passes::ReplaceAtenInt (sg);
536
+
537
+ auto tg = std::make_shared<torch::jit::Graph>();
538
+ torch::jit::parseIR (target_graph_no_inputs, tg.get ());
539
+
540
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
541
+ }
542
+
543
+ TEST (LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
544
+ // Ensure the lowering pass transforms the first graph into the second
545
+ std::string source_graph = R"IR(
546
+ graph(%0: int):
547
+ %1: Tensor = prim::Constant[value=[8]]()
548
+ %3: Tensor = prim::NumToTensor(%0)
549
+ %4: Tensor = aten::floor_divide(%3, %1)
550
+ %5: int = aten::Int(%4)
551
+ return (%5))IR" ;
552
+
553
+ std::string target_graph = R"IR(
554
+ graph(%0 : int):
555
+ %1 : Tensor = prim::Constant[value=[8]]()
556
+ %2 : int = prim::Constant[value=8]()
557
+ %3 : int = aten::floordiv(%0, %2)
558
+ return (%3))IR" ;
559
+
560
+ auto sg = std::make_shared<torch::jit::Graph>();
561
+ torch::jit::parseIR (source_graph, &*sg);
562
+
563
+ // Manually enter 0d tensor const for source
564
+ auto first_op_sg = *(sg->block ()->nodes ().begin ());
565
+ torch::jit::Value* r_sg = sg->insertConstant (c10::scalar_to_tensor (8 ), c10::nullopt, first_op_sg->scope ());
566
+ r_sg->copyMetadata (first_op_sg->output ());
567
+ r_sg->setType (c10::TensorType::get ());
568
+ first_op_sg->output ()->replaceAllUsesWith (r_sg);
569
+ first_op_sg->destroy ();
570
+
571
+ torch_tensorrt::core::lowering::passes::ReplaceAtenInt (sg);
572
+ torch::jit::ConstantPooling (sg);
573
+ sg = torch::jit::Canonicalize (sg, false );
574
+
575
+ auto tg = std::make_shared<torch::jit::Graph>();
576
+ torch::jit::parseIR (target_graph, &*tg);
577
+
578
+ // Manually enter 0d tensor const for target
579
+ auto first_op_tg = *(tg->block ()->nodes ().begin ());
580
+ torch::jit::Value* r_tg = tg->insertConstant (c10::scalar_to_tensor (8 ), c10::nullopt, first_op_tg->scope ());
581
+ r_tg->copyMetadata (first_op_tg->output ());
582
+ r_tg->setType (c10::TensorType::get ());
583
+ first_op_tg->output ()->replaceAllUsesWith (r_tg);
584
+ first_op_tg->destroy ();
585
+
586
+ torch::jit::ConstantPooling (tg);
587
+ tg = torch::jit::Canonicalize (tg, false );
588
+
589
+ // Validate identical graphs after pooling constants and canonicalizing
590
+ ASSERT_TRUE ((tg->toString () == sg->toString ()));
591
+ }
0 commit comments