
Commit 4acc3fd

feat(//core/lowering): New freeze model pass and new exception elimination pass

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent 90c44b9 commit 4acc3fd

17 files changed: +188 -46 lines changed

Diff for: core/conversion/converters/impl/linear.cpp

+7

@@ -9,6 +9,13 @@ namespace impl {
 namespace {

 auto linear_registrations = RegisterNodeConversionPatterns()
+    // .pattern({
+    //     "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
+    //     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> {
+    //         auto in = args[0].ITensor();
+
+    //     }
+    // })
     .pattern({
         "aten::linear(Tensor input, Tensor weight, Tensor? bias = None) -> (Tensor)",
         [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

Diff for: core/lowering/BUILD

+2 -1

@@ -11,7 +11,8 @@ cc_library(
     ],
     deps = [
         "@libtorch//:libtorch",
-        "//core/lowering/irfusers"
+        "//core/lowering/passes",
+        "//core/util:prelude"
     ]
 )

Diff for: core/lowering/lowering.cpp

+22 -13

@@ -1,10 +1,13 @@
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/fuse_linear.h"
+#include "torch/csrc/jit/passes/freeze_module.h"
 #include "torch/csrc/jit/passes/lower_graph.h"
 #include "torch/csrc/jit/passes/quantization.h"
+#include "torch/csrc/jit/passes/guard_elimination.h"

+#include "core/util/prelude.h"
 #include "core/lowering/lowering.h"
-#include "core/lowering/irfusers/irfusers.h"
+#include "core/lowering/passes/passes.h"

 namespace trtorch {
 namespace core {
@@ -17,30 +20,36 @@ void LowerBlock(torch::jit::Block* b) {
 }

 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
+    torch::jit::EliminateRedundantGuards(g);
+    passes::EliminateExceptionOrPassPattern(g);
     torch::jit::FuseLinear(g);
-    irfusers::RemoveDropout(g);
-    irfusers::FuseFlattenLinear(g);
-    irfusers::ExpandLogSoftmax(g);
+    passes::RemoveDropout(g);
+    passes::FuseFlattenLinear(g);
+    passes::ExpandLogSoftmax(g);
+    //passes::RemoveDimExeception(g);
     //irfusers::UnpackBatchNorm(g);
-    //torch::jit::EliminateDeadCode(g);
+    torch::jit::EliminateDeadCode(g);
+    LOG_GRAPH(*g);
 }

-void LowerModule(const torch::jit::script::Module& mod) {
-    torch::jit::FoldConvBatchNorm2d(mod);
+torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
+    auto mod_ = torch::jit::freeze_module(mod);
+    return mod_;
 }

 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
                                                                              std::string method_name) {
-    LowerModule(mod);
-    auto g = mod.get_method(method_name).graph();
-    // Go through PyTorch Lowering to simplify graph and extract weight parameters
-    auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());
-
-    g = graph_and_parameters.first;
+    auto lowered_mod = LowerModule(mod);
+    auto g = lowered_mod.get_method(method_name).graph();
+    LOG_GRAPH(*g);

     // Go through TRTorch Lowering to reformat graph to be conversion friendly
     // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
+    LOG_GRAPH("TRTorch Graph Lowering");
     lowering::LowerGraph(g);
+    //=[torch::jit::FoldConvBatchNorm2d(lowered_mod);
+    LOG_GRAPH("LibTorch Lowering");
+    auto graph_and_parameters = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
     // Is this necessary?
     lowering::LowerBlock(g->block());
     return graph_and_parameters;
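For context, here is a minimal usage sketch of the reworked Lower entry point as it stands after this change. The main() harness and the "model.ts" path are hypothetical illustrations and are not part of this commit:

#include "torch/script.h"
#include "core/lowering/lowering.h"

int main() {
    // Load a serialized TorchScript module (the path is a placeholder)
    auto mod = torch::jit::load("model.ts");

    // Lower() now freezes the module via LowerModule, runs the TRTorch passes
    // (guard elimination, exception elimination, fusions), then calls
    // torch::jit::LowerGraph to extract the weight parameters
    auto graph_and_parameters = trtorch::core::lowering::Lower(mod, "forward");

    auto g = graph_and_parameters.first;        // conversion-friendly graph
    auto params = graph_and_parameters.second;  // extracted weight tensors

    g->dump();  // print the lowered graph for inspection
    return 0;
}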

Diff for: core/lowering/lowering.h

+1 -1

@@ -8,7 +8,7 @@ namespace lowering {

 void LowerBlock(torch::jit::Block* b);
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g);
-void LowerModule(const torch::jit::script::Module& mod);
+torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<at::Tensor>> Lower(const torch::jit::script::Module& mod,
                                                                              std::string method_name);

Diff for: core/lowering/irfusers/BUILD renamed to core/lowering/passes/BUILD

+5 -3

@@ -1,17 +1,19 @@
 package(default_visibility = ["//visibility:public"])

 cc_library(
-    name = "irfusers",
+    name = "passes",
     hdrs = [
-        "irfusers.h",
+        "passes.h",
     ],
     srcs = [
         "fuse_flatten_linear.cpp",
         "expand_log_softmax.cpp",
         "remove_dropout.cpp",
-        "unpack_batch_norm.cpp"
+        "unpack_batch_norm.cpp",
+        "exception_elimination.cpp"
     ],
     deps = [
+        "//core/util:prelude",
         "@libtorch//:libtorch",
     ]
 )

Diff for: core/lowering/passes/exception_elimination.cpp

+86

@@ -0,0 +1,86 @@
+#include "torch/csrc/jit/passes/guard_elimination.h"
+#include "torch/csrc/jit/ir/alias_analysis.h"
+#include "torch/csrc/jit/jit_log.h"
+#include "torch/csrc/jit/passes/constant_propagation.h"
+#include "torch/csrc/jit/passes/peephole.h"
+#include "torch/csrc/jit/runtime/graph_executor.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
+
+#include "core/util/prelude.h"
+
+#include <vector>
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+namespace {
+using namespace torch::jit;
+struct ExceptionOrPassPatternElimination {
+  ExceptionOrPassPatternElimination(std::shared_ptr<Graph> graph)
+      : graph_(std::move(graph)) {}
+
+  void run() {
+    LOG_GRAPH("Pre exeception or pass elimination: " << *graph_);
+    findExceptionOrPassNodes(graph_->block());
+    torch::jit::EliminateDeadCode(graph_);
+    LOG_GRAPH("Post exeception or pass elimination: " << *graph_);
+  }
+
+private:
+  bool isExceptionOrPassNode(Node* n) {
+    /// Check if this Node hosts a pattern like so:
+    ///  = prim::If(%5958)
+    ///    block0():
+    ///       = prim::RaiseException(%45)
+    ///      -> ()
+    ///    block1():
+    ///      -> ()
+    if (n->blocks().size() != 2) {
+      return false;
+    }
+    auto arm1 = n->blocks()[0];
+    auto arm2 = n->blocks()[1];
+    if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
+      // Make sure that the node doesn't actually produce any Value that are used by other nodes
+      return false;
+    }
+
+    auto arm1_start = arm1->nodes().begin();
+
+    if ((*arm1_start)->kind() != prim::RaiseException && (*(++arm1_start))->kind() != prim::Return) {
+      // Make sure that block0 is solely just the exception and the return
+      return false;
+    }
+
+    if ((*(arm2->nodes().begin()))->kind() != prim::Return) {
+      // Make sure that block1 is solely the return
+      return false;
+    }
+
+    return true;
+  }
+
+  void findExceptionOrPassNodes(Block* b) {
+    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
+      auto n = *it;
+      if (n->kind() == prim::If && isExceptionOrPassNode(n)) {
+        LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)");
+        it.destroyCurrent();
+      }
+    }
+  }
+
+  std::shared_ptr<Graph> graph_;
+};
+} // namespace
+
+void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
+  ExceptionOrPassPatternElimination eppe(std::move(graph));
+  eppe.run();
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
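To illustrate what the new pass targets, below is a hedged sketch that hand-builds a graph containing the exception-or-pass pattern and runs EliminateExceptionOrPassPattern over it. The IR text and the use of torch::jit::parseIR are illustrative assumptions and are not part of this commit:

#include <memory>
#include <string>

#include "torch/csrc/jit/ir/irparser.h"
#include "core/lowering/passes/passes.h"

int main() {
    // A guard of the kind TorchScript emits for dimension checks:
    // a prim::If whose true arm raises and whose false arm falls through.
    const std::string ir = R"IR(
        graph(%x : Tensor, %cond : bool):
            %msg : str = prim::Constant[value="dimension check failed"]()
            prim::If(%cond)
              block0():
                prim::RaiseException(%msg)
                -> ()
              block1():
                -> ()
            return (%x))IR";

    auto g = std::make_shared<torch::jit::Graph>();
    torch::jit::parseIR(ir, g.get());

    // The prim::If produces no values used elsewhere, so the pass can drop it
    trtorch::core::lowering::passes::EliminateExceptionOrPassPattern(g);
    g->dump();  // the guard and its RaiseException arm should be gone
    return 0;
}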

Diff for: core/lowering/irfusers/expand_log_softmax.cpp renamed to core/lowering/passes/expand_log_softmax.cpp

+5 -5

@@ -4,14 +4,14 @@
 namespace trtorch {
 namespace core {
 namespace lowering {
-namespace irfusers {
+namespace passes {

 void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
     // Its easier for TensorRT if we seperate softmax and log
     // There might need to be a reshape inserted see:
     // https://github.com/onnx/onnx-tensorrt/blob/5dca8737851118f6ab8a33ea1f7bcb7c9f06caf5/builtin_op_importers.cpp#L1593
     // Should the reshapes be added here or in the converter?
-
+
     // TODO: In the future this should be removed for a deicated log_softmax converter (more efficent)
     // But its easier to stand up a working system if the number of op converters is lower
     std::string logsoftmax_pattern = R"IR(
@@ -33,19 +33,19 @@ void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
         %dtype : int? = prim::Constant()
         %softmax = aten::softmax(%input, %dim, %dtype)
         %log_softmax = aten::log(%softmax)
-        return (%log_softmax))IR";
+        return (%log_softmax))IR";

     torch::jit::SubgraphRewriter logsoftmax_to_softmax_log;
     logsoftmax_to_softmax_log.RegisterRewritePattern(logsoftmax_pattern, softmax_log_pattern);
     logsoftmax_to_softmax_log.runOnGraph(graph);
-
+
     torch::jit::SubgraphRewriter logsoftmax_none_to_softmax_log_none;
     logsoftmax_none_to_softmax_log_none.RegisterRewritePattern(
         logsoftmax_none_pattern, softmax_log_none_pattern);
     logsoftmax_none_to_softmax_log_none.runOnGraph(graph);
 }

-} // namespace irfusers
+} // namespace passes
 } // namespace lowering
 } // namespace core
 } // namespace trtorch

Diff for: core/lowering/irfusers/fuse_flatten_linear.cpp renamed to core/lowering/passes/fuse_flatten_linear.cpp

+37 -3

@@ -4,7 +4,7 @@
 namespace trtorch {
 namespace core {
 namespace lowering {
-namespace irfusers {
+namespace passes {

 void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
     //TensorRT implicitly adds a flatten layer infront of FC layers if necessary
@@ -33,13 +33,47 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
     torch::jit::SubgraphRewriter flatten_linear_to_linear;
     flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
     flatten_linear_to_linear.runOnGraph(graph);
-
+
+    torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
+    flatten_linear_bias_none_to_linear.RegisterRewritePattern(
+        flatten_linear_bias_none_pattern, fused_linear_bias_none);
+    flatten_linear_bias_none_to_linear.runOnGraph(graph);
+}
+
+void FuseFlattenAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
+    //TensorRT implicitly adds a flatten layer infront of FC layers if necessary
+    std::string flatten_linear_pattern = R"IR(
+        graph(%input, %6, %7, %weight, %bias):
+            %flat = aten::flatten(%input, %6, %7)
+            %res = aten::linear(%flat, %weight, %bias)
+            return (%res))IR";
+    std::string flatten_linear_bias_none_pattern = R"IR(
+        graph(%input, %6, %7, %weight):
+            %flat = aten::flatten(%input, %6, %7)
+            %bias: Tensor? = prim::Constant()
+            %res = aten::linear(%flat, %weight, %bias)
+            return (%res))IR";
+    std::string fused_linear = R"IR(
+        graph(%input, %6, %7, %weight, %bias):
+            %res = aten::linear(%input, %weight, %bias)
+            return (%res))IR";
+
+    std::string fused_linear_bias_none = R"IR(
+        graph(%input, %6, %7, %weight):
+            %bias: Tensor? = prim::Constant()
+            %res = aten::linear(%input, %weight, %bias)
+            return (%res))IR";
+
+    torch::jit::SubgraphRewriter flatten_linear_to_linear;
+    flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
+    flatten_linear_to_linear.runOnGraph(graph);
+
     torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
     flatten_linear_bias_none_to_linear.RegisterRewritePattern(
         flatten_linear_bias_none_pattern, fused_linear_bias_none);
     flatten_linear_bias_none_to_linear.runOnGraph(graph);
 }
-} // namespace irfusers
+} // namespace passes
 } // namespace lowering
 } // namespace core
 } // namespace trtorch

Diff for: core/lowering/irfusers/irfusers.h renamed to core/lowering/passes/passes.h

+2 -1

@@ -5,12 +5,13 @@
 namespace trtorch {
 namespace core {
 namespace lowering {
-namespace irfusers {
+namespace passes {

 void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
 void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
+void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);

 } // namespace irfusers
 } // namespace lowering

Diff for: core/lowering/irfusers/remove_dropout.cpp renamed to core/lowering/passes/remove_dropout.cpp

+3 -3

@@ -4,7 +4,7 @@
 namespace trtorch {
 namespace core {
 namespace lowering {
-namespace irfusers {
+namespace passes {

 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
     std::string dropout_pattern = R"IR(
@@ -14,15 +14,15 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
     std::string no_dropout_pattern = R"IR(
         graph(%input, %4, %5):
             return (%input))IR";
-
+
     // replace matmul + add pattern to linear
     torch::jit::SubgraphRewriter remove_dropout;
     remove_dropout.RegisterRewritePattern(
         dropout_pattern, no_dropout_pattern);
     remove_dropout.runOnGraph(graph);
 }

-} // namespace irfusers
+} // namespace passes
 } // namespace lowering
 } // namespace core
 } // namespace trtorch

Diff for: core/lowering/irfusers/unpack_batch_norm.cpp renamed to core/lowering/passes/unpack_batch_norm.cpp

+2 -2

@@ -23,7 +23,7 @@ RegisterOperators trt_const_op_reg({
 namespace trtorch {
 namespace core {
 namespace lowering {
-namespace irfusers {
+namespace passes {

 // // May be abusing aten::_tensor_to_list(Tensor self) -> int[]
 // // Treating it as an emit_constant by the converters
@@ -60,7 +60,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph) {
     unpack_batch_norm.RegisterRewritePattern(batch_norm_pattern, expanded_batch_norm_pattern);
     unpack_batch_norm.runOnGraph(graph);
 }
-} // Namespace Irfusers
+} // Namespace passes
 } // namespace lowering
 } // namespace core
 } // namespace trtorch

Diff for: core/util/logging/TRTorchLogger.cpp

+1 -1

@@ -101,7 +101,7 @@ namespace {
 TRTorchLogger& get_global_logger() {
 #ifndef NDEBUG
     static TRTorchLogger global_logger("[TRTorch - Debug Build] - ",
-                                       LogLevel::kDEBUG,
+                                       LogLevel::kGRAPH,
                                        true);
 #else
     static TRTorchLogger global_logger("[TRTorch] - ",

Diff for: core/util/macros.h

+3 -3

@@ -11,21 +11,21 @@
         l.log(sev, ss.str()); \
     } while (0)

-#define GRAPH_DUMP_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kGRAPH, s)
+#define LOG_GRAPH_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kGRAPH, s)
 #define LOG_DEBUG_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kDEBUG, s)
 #define LOG_INFO_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINFO, s)
 #define LOG_WARNING_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kWARNING, s)
 #define LOG_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kERROR, s)
 #define LOG_INTERNAL_ERROR_GLOBAL(s) TRTORCH_LOG(core::util::logging::get_logger(), core::util::logging::LogLevel::kINTERNAL_ERROR, s)

-#define GRAPH_DUMP_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kGRAPH, s)
+#define LOG_GRAPH_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kGRAPH, s)
 #define LOG_DEBUG_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kDEBUG, s)
 #define LOG_INFO_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINFO, s)
 #define LOG_WARNING_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kWARNING, s)
 #define LOG_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kERROR, s)
 #define LOG_INTERNAL_ERROR_OWN(l,s) TRTORCH_LOG(l, core::util::logging::LogLevel::kINTERNAL_ERROR, s)

-#define GRAPH_DUMP(...) GET_MACRO(__VA_ARGS__, GRAPH_DUMP_OWN, GRAPH_DUMP_GLOBAL)(__VA_ARGS__)
+#define LOG_GRAPH(...) GET_MACRO(__VA_ARGS__, LOG_GRAPH_OWN, LOG_GRAPH_GLOBAL)(__VA_ARGS__)
 #define LOG_DEBUG(...) GET_MACRO(__VA_ARGS__, LOG_DEBUG_OWN, LOG_DEBUG_GLOBAL)(__VA_ARGS__)
 #define LOG_INFO(...) GET_MACRO(__VA_ARGS__, LOG_INFO_OWN, LOG_INFO_GLOBAL)(__VA_ARGS__)
 #define LOG_WARNING(...) GET_MACRO(__VA_ARGS__, LOG_WARNING_OWN, LOG_WARNING_GLOBAL)(__VA_ARGS__)
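For reference, a small hedged sketch of how the renamed macro is typically invoked; the surrounding function and logger argument are hypothetical, and only the macro names come from this diff:

#include "core/util/prelude.h"

namespace trtorch {

void log_examples(core::util::logging::TRTorchLogger& my_logger) {
    // One argument: GET_MACRO selects LOG_GRAPH_GLOBAL, which logs through the
    // global logger at kGRAPH severity
    LOG_GRAPH("TRTorch Graph Lowering");

    // Two arguments: GET_MACRO selects LOG_GRAPH_OWN, which logs through the
    // caller-supplied logger
    LOG_GRAPH(my_logger, "lowering pass complete");
}

} // namespace trtorch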
