Skip to content

Commit 9ee3a04

Browse files
committed
fix: Implement a patch for gelu schema change in older NGC containers
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent d6694db commit 9ee3a04

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

Diff for: core/lowering/passes/reduce_gelu.cpp

+33-1
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,17 @@ namespace passes {
88

99
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
1010
std::string gelu_pattern = R"IR(
11-
graph(%x):
11+
graph(%x : Tensor):
1212
%out : Tensor = aten::gelu(%x)
1313
return (%out))IR";
1414

15+
// This gelu_approximate_pattern schema exists in 21.11, 21.12, 22.01 containers of pytorch. These container versions use
16+
// an unmerged PR in pytorch : https://github.com/pytorch/pytorch/pull/61439. We reduce this to regular Gelu.
17+
std::string gelu_approximate_pattern = R"IR(
18+
graph(%x : Tensor, %approx):
19+
%out : Tensor = aten::gelu(%x, %approx)
20+
return (%out))IR";
21+
1522
std::string gelu_reduce_pattern = R"IR(
1623
graph(%x.1 : Tensor):
1724
%6 : float = prim::Constant[value=0.044714999999999998]()
@@ -30,11 +37,36 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
3037
%15 : Tensor = aten::mul(%7, %14)
3138
return (%15))IR";
3239

40+
// This is the same as gelu_reduce_pattern except for an additional input %approx.
41+
// SubgraphRewriter only works as expected if the number of inputs to gelu_approximate_pattern
42+
// and gelu_reduce_multi_input_pattern are the same.
43+
std::string gelu_reduce_multi_input_pattern = R"IR(
44+
graph(%x.1 : Tensor, %approx):
45+
%6 : float = prim::Constant[value=0.044714999999999998]()
46+
%5 : float = prim::Constant[value=0.79788456080000003]()
47+
%4 : float = prim::Constant[value=1.]()
48+
%3 : float = prim::Constant[value=0.5]()
49+
%2 : int = prim::Constant[value=1]()
50+
%7 : Tensor = aten::mul(%x.1, %3)
51+
%8 : Tensor = aten::mul(%x.1, %5)
52+
%9 : Tensor = aten::mul(%x.1, %6)
53+
%10 : Tensor = aten::mul(%9, %x.1)
54+
%11 : Tensor = aten::add(%10, %4, %2)
55+
%12 : Tensor = aten::mul(%8, %11)
56+
%13 : Tensor = aten::tanh(%12)
57+
%14 : Tensor = aten::add(%13, %4, %2)
58+
%15 : Tensor = aten::mul(%7, %14)
59+
return (%15))IR";
60+
3361
// replace aten::gelu with pointwise operations
3462
torch::jit::SubgraphRewriter map_gelu_to_pointwise_ops;
3563
map_gelu_to_pointwise_ops.RegisterRewritePattern(gelu_pattern, gelu_reduce_pattern);
3664
map_gelu_to_pointwise_ops.runOnGraph(graph);
3765

66+
torch::jit::SubgraphRewriter map_gelu_approximate_to_pointwise_ops;
67+
map_gelu_approximate_to_pointwise_ops.RegisterRewritePattern(gelu_approximate_pattern, gelu_reduce_multi_input_pattern);
68+
map_gelu_approximate_to_pointwise_ops.runOnGraph(graph);
69+
3870
LOG_GRAPH("Post lowering of [aten::gelu] -> " << *graph);
3971
}
4072

Diff for: tests/core/lowering/test_reduce_gelu.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,38 @@ TEST(LoweringPasses, ReduceGeluCorrectly) {
4040

4141
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
4242
}
43+
44+
TEST(LoweringPasses, ReduceGeluApproximateCorrectly) {
45+
std::string source_graph = R"IR(
46+
graph(%x, %approx):
47+
%out : Tensor = aten::gelu(%x, %approx)
48+
return (%out))IR";
49+
std::string target_graph = R"IR(
50+
graph(%x.1 : Tensor, %approx):
51+
%6 : float = prim::Constant[value=0.044714999999999998]()
52+
%5 : float = prim::Constant[value=0.79788456080000003]()
53+
%4 : float = prim::Constant[value=1.]()
54+
%3 : float = prim::Constant[value=0.5]()
55+
%2 : int = prim::Constant[value=1]()
56+
%7 : Tensor = aten::mul(%x.1, %3)
57+
%8 : Tensor = aten::mul(%x.1, %5)
58+
%9 : Tensor = aten::mul(%x.1, %6)
59+
%10 : Tensor = aten::mul(%9, %x.1)
60+
%11 : Tensor = aten::add(%10, %4, %2)
61+
%12 : Tensor = aten::mul(%8, %11)
62+
%13 : Tensor = aten::tanh(%12)
63+
%14 : Tensor = aten::add(%13, %4, %2)
64+
%15 : Tensor = aten::mul(%7, %14)
65+
return (%15))IR";
66+
67+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
68+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
69+
auto sg = std::make_shared<torch::jit::Graph>();
70+
torch::jit::parseIR(source_graph, &*sg);
71+
torch_tensorrt::core::lowering::passes::ReduceGelu(sg);
72+
73+
auto tg = std::make_shared<torch::jit::Graph>();
74+
torch::jit::parseIR(target_graph, &*tg);
75+
76+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
77+
}

0 commit comments

Comments
 (0)