Skip to content

Commit 370aeb9

Browse files
committed
fix(//core/lowering): use lower_info as parameter
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 74bbd10 commit 370aeb9

File tree

4 files changed

+12
-11
lines changed

4 files changed

+12
-11
lines changed

Diff for: core/compiler.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,9 @@ void AddEngineToGraph(
118118
}
119119

120120
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
121-
// Go through Lowering to simplify graph and extract weight parameters
122-
auto graph_and_parameters = lowering::Lower(mod, method_name, false);
121+
// Go through Lowering to simplify graph
122+
CompileSpec cfg({});
123+
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
123124

124125
auto g = graph_and_parameters.first;
125126
LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
@@ -129,7 +130,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
129130

130131
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
131132
// Go through Lowering to simplify graph and extract weight parameters
132-
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info.unfreeze_module);
133+
auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
133134

134135
auto convert_cfg = std::move(cfg.convert_info);
135136
auto g = graph_and_parameters.first;
@@ -187,7 +188,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
187188
// Compile only forward methods. forward method contains the entire graph.
188189
if (method.name().compare("forward") == 0) {
189190
auto new_g = std::make_shared<torch::jit::Graph>();
190-
auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info.unfreeze_module);
191+
auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);
191192

192193
auto g = graph_and_parameters.first;
193194
auto params = graph_and_parameters.second;

Diff for: core/lowering/lowering.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
6565
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
6666
const torch::jit::script::Module& mod,
6767
std::string method_name,
68-
bool unfreeze_module = false) {
69-
auto lowered_mod = unfreeze_module ? mod : LowerModule(mod);
68+
LowerInfo lower_info) {
69+
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod);
7070
auto g = lowered_mod.get_method(method_name).graph();
7171
LOG_GRAPH(*g);
7272

@@ -75,15 +75,15 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
7575
// unfreeze_module is used to not perform constant folding on weights in the network.
7676
// In quantization aware trained (QAT) models, weights are passed through quantize and
7777
// dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
78-
if (!unfreeze_module) {
78+
if (!lower_info.unfreeze_module) {
7979
LOG_GRAPH("TRTorch Graph Lowering");
8080
lowering::LowerGraph(g, false);
8181
}
8282

8383
LOG_GRAPH("LibTorch Lowering");
8484
auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
8585

86-
if (unfreeze_module) {
86+
if (lower_info.unfreeze_module) {
8787
LOG_GRAPH("TRTorch Graph Lowering");
8888
lowering::LowerGraph(graph_and_ivalues.first, true);
8989
}

Diff for: core/lowering/lowering.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ struct LowerInfo {
1212
};
1313

1414
void LowerBlock(torch::jit::Block* b);
15-
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse /*=false*/);
15+
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse=false);
1616
torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
1717
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
1818
const torch::jit::script::Module& mod,
1919
std::string method_name,
20-
bool unfreeze_module /*=false*/);
20+
LowerInfo lower_info);
2121

2222
} // namespace lowering
2323
} // namespace core

Diff for: tests/core/conversion/converters/test_activation.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ TEST(Converters, ATenSigmoidConvertsCorrectly) {
4141
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
4242
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
4343

44-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 4e-6));
44+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 5e-6));
4545
}
4646

4747
TEST(Converters, ATenTanhConvertsCorrectly) {

0 commit comments

Comments
 (0)