
Commit fdec075

Merge pull request #1298 from pytorch/param_input
feat: rewriting param to a Constant if it's an introduced input
2 parents fe966ed + b09994d commit fdec075

7 files changed: +84 -2 lines

core/lowering/lowering.cpp

+3 -2
@@ -26,7 +26,7 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }
 
-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
   torch::jit::RemoveTensorMutation(g);
@@ -70,6 +70,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::RewriteInputsWithParams(g, params);
   LOG_GRAPH(*g);
 }
 
@@ -103,7 +104,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
   // In quantization aware trained (QAT) models, weights are passed through quantize and
   // dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
   LOG_GRAPH("Torch-TensorRT.TorchScript Graph Lowering");
-  lowering::LowerGraph(graph_and_ivalues.first, lower_info);
+  lowering::LowerGraph(graph_and_ivalues.first, graph_and_ivalues.second, lower_info);
 
   // Is this necessary?
   // lowering::LowerBlock(g->block());
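
Note (not part of the commit): the new second argument carries the parameter IValues that torch::jit graph lowering pairs with the graph's introduced inputs. A minimal sketch of the calling convention is below; the freeze_module/LowerGraph setup and the lower_module_sketch name are illustrative assumptions, and only the final torch_tensorrt::core::lowering::LowerGraph call mirrors the diff.

#include "core/lowering/lowering.h"
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/passes/freeze_module.h"
#include "torch/csrc/jit/passes/lower_graph.h"

// Sketch: thread the frozen module's parameter IValues into the lowering
// pipeline so RewriteInputsWithParams can fold them into the graph.
void lower_module_sketch(torch::jit::Module& mod, torch_tensorrt::core::lowering::LowerInfo lower_info) {
  mod.eval();  // freezing expects an eval-mode module
  auto frozen = torch::jit::freeze_module(mod);
  // torch::jit::LowerGraph returns {graph, one IValue per introduced input}
  auto graph_and_ivalues = torch::jit::LowerGraph(*frozen.get_method("forward").graph(), frozen._ivalue());
  // Params now travel with the graph so the new pass can rewrite them to Constants.
  torch_tensorrt::core::lowering::LowerGraph(graph_and_ivalues.first, graph_and_ivalues.second, lower_info);
}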

core/lowering/passes/BUILD

+1
@@ -27,6 +27,7 @@ cc_library(
         "remove_dropout.cpp",
         "remove_nops.cpp",
         "remove_unnecessary_casts.cpp",
+        "rewrite_inputs_with_params.cpp",
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",

core/lowering/passes/CMakeLists.txt

+1
@@ -24,6 +24,7 @@ target_sources(${lib_name}
         "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
         "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
         "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"
+        "${CMAKE_CURRENT_SOURCE_DIR}/rewrite_inputs_with_params.cpp"
 )
 
 set(HEADER_FILES

core/lowering/passes/passes.h

+1
@@ -39,6 +39,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
+void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params);
 void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
 
 } // namespace passes
core/lowering/passes/rewrite_inputs_with_params.cpp

+40
@@ -0,0 +1,40 @@
+#include "core/util/prelude.h"
+#include "torch/csrc/jit/ir/constants.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params) {
+  auto input_size = g->inputs().size();
+  auto param_it = params.rbegin();
+  for (int i = input_size - 1; i >= 0; --i) {
+    if (g->inputs()[i]->type() != c10::TensorType::get() &&
+        g->inputs()[i]->type()->kind() != torch::jit::TypeKind::TupleType &&
+        g->inputs()[i]->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.rend()) {
+      auto val = *param_it;
+      if (val.isTensor()) {
+        at::Tensor val_tensor = val.toTensor();
+        if (val_tensor.requires_grad()) {
+          val_tensor.set_requires_grad(false);
+          val = val_tensor;
+        }
+      }
+      auto new_constant = torch::jit::tryInsertConstant(*g, val);
+      ++param_it;
+      if (new_constant) {
+        g->inputs()[i]->replaceAllUsesWith(*new_constant);
+        g->eraseInput(i);
+        // erase an iterator, should be safe
+        params.erase(param_it.base());
+      }
+    }
+  }
+  LOG_GRAPH("After RewriteInputsWithParams: " << *g);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
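
Note (not part of the commit): a minimal standalone sketch of what the pass does. A non-Tensor graph input that freezing backed with a parameter IValue gets replaced by an inlined prim::Constant and then erased; inputs are walked back-to-front and paired with params in reverse. The IR text and main() below are illustrative assumptions; only RewriteInputsWithParams comes from this commit.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  // %alpha is an "introduced" non-Tensor input whose value lives in params.
  const std::string ir = R"IR(
    graph(%x : Tensor, %y : Tensor, %alpha : int):
      %out : Tensor = aten::sub(%x, %y, %alpha)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, g.get());

  // One IValue per introduced input; the last param backs the last such input.
  std::vector<torch::jit::IValue> params{torch::jit::IValue(1)};
  torch_tensorrt::core::lowering::passes::RewriteInputsWithParams(g, params);

  // %alpha is gone; aten::sub now reads prim::Constant[value=1]() instead.
  std::cout << *g << std::endl;
  return 0;
}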

tests/core/lowering/BUILD

+5
@@ -87,6 +87,10 @@ lowering_test(
     name = "test_unpack_reduce_ops",
 )
 
+lowering_test(
+    name = "test_rewrite_inputs_with_params",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [
@@ -102,6 +106,7 @@ test_suite(
         ":test_remove_detach_pass",
         ":test_remove_dropout_pass",
         ":test_remove_unnecessary_casts",
+        ":test_rewrite_inputs_with_params",
         ":test_unpack_hardsigmoid",
         ":test_unpack_hardswish",
         ":test_unpack_reduce_ops",
tests/core/lowering/test_rewrite_inputs_with_params.cpp

+33
@@ -0,0 +1,33 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, RewriteInputsWithParamsCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%x: Tensor, %y: Tensor, %1 : Int(1)):
+      %out: Tensor = aten::sub(%x, %y, %1)
+      return (%out))IR";
+  std::string target_graph = R"IR(
+    graph(%x: Tensor, %y : Tensor):
+      %2 : int = prim::Constant[value=0]()
+      %out: Tensor = aten::sub(%x, %y, %2)
+      return (%out))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  torch::jit::IValue param0 = torch::jit::IValue(0);
+  std::vector<torch::jit::IValue> params{param0};
+  torch_tensorrt::core::lowering::passes::RewriteInputsWithParams(sg, params);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
