Commit 85e5e99

feat: Rsqrt lowering pass (#1394)
* feat: Add lowering pass for rsqrt operator
  - Add unpack rsqrt lowering pass
  - Add test cases for positive inputs, int and float
  - Add references to new function in headers and BUILD files
* Added UnpackRsqrt to lowering passes list
1 parent: e27103a
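
The pass unpacks aten::rsqrt into aten::sqrt followed by aten::reciprocal, relying on the identity rsqrt(x) = 1 / sqrt(x), so later stages only need converters for the two simpler unary ops. As a standalone sanity check of that identity (a minimal sketch, not part of this commit; it uses only standard LibTorch/ATen calls, runs on CPU, and mirrors the shape and offset of the float test case added below):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Positive inputs, as in the float test: rand() is in [0, 1),
  // so adding 0.01 keeps sqrt away from zero and negatives.
  auto x = torch::rand({2, 3, 5, 7}) + 0.01;

  auto fused = torch::rsqrt(x);                       // what the lowering pass matches
  auto unpacked = torch::reciprocal(torch::sqrt(x));  // what the lowering pass emits

  std::cout << std::boolalpha
            << torch::allclose(fused, unpacked, /*rtol=*/1e-5, /*atol=*/2e-6)
            << std::endl;  // expected: true
  return 0;
}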

6 files changed (+76, -0 lines)

core/lowering/lowering.cpp (+1)

@@ -60,6 +60,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::UnpackAddMM(g);
   // passes::UnpackBatchNorm(g);
   passes::UnpackLogSoftmax(g);
+  passes::UnpackRsqrt(g);
   passes::UnpackStd(g);
   passes::UnpackVar(g);
   passes::RemoveNOPs(g);

core/lowering/passes/BUILD (+1)

@@ -33,6 +33,7 @@ cc_library(
         "unpack_hardsigmoid.cpp",
         "unpack_hardswish.cpp",
         "unpack_log_softmax.cpp",
+        "unpack_rsqrt.cpp",
         "unpack_std.cpp",
         "unpack_var.cpp",
         "view_to_reshape.cpp",

core/lowering/passes/CMakeLists.txt (+1)

@@ -20,6 +20,7 @@ target_sources(${lib_name}
         "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp"
         "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp"
         "${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp"
+        "${CMAKE_CURRENT_SOURCE_DIR}/unpack_rsqrt.cpp"
         "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
         "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
         "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"

core/lowering/passes/passes.h (+1)

@@ -33,6 +33,7 @@ void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackRsqrt(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackStd(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/unpack_rsqrt.cpp (+30, new file)

@@ -0,0 +1,30 @@
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackRsqrt(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string rsqrt_pattern = R"IR(
+    graph(%1):
+      %out: Tensor = aten::rsqrt(%1)
+      return (%out))IR";
+  std::string unpacked_pattern = R"IR(
+    graph(%1):
+      %intermediate: Tensor = aten::sqrt(%1)
+      %out: Tensor = aten::reciprocal(%intermediate)
+      return (%out))IR";
+
+  torch::jit::SubgraphRewriter rsqrt_rewriter;
+  rsqrt_rewriter.RegisterRewritePattern(rsqrt_pattern, unpacked_pattern);
+  rsqrt_rewriter.runOnGraph(graph);
+  LOG_GRAPH("Post unpack rsqrt: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
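
SubgraphRewriter performs a structural match on the fused pattern and splices in the unpacked form wherever it appears, so the rewrite can be inspected on a parsed graph without building a TensorRT engine. A rough sketch of such an inspection (an illustration only, not part of this commit; it assumes the Torch-TensorRT core headers and LibTorch are on the include path, as in the test targets below):

#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

#include <iostream>
#include <memory>

int main() {
  // A one-op graph that calls the fused aten::rsqrt.
  const auto src = R"IR(
    graph(%x : Tensor):
      %y : Tensor = aten::rsqrt(%x)
      return (%y))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());

  // Run only this pass; the printed IR should now contain
  // aten::sqrt feeding aten::reciprocal instead of aten::rsqrt.
  torch_tensorrt::core::lowering::passes::UnpackRsqrt(g);
  std::cout << *g << std::endl;
  return 0;
}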

tests/core/lowering/test_unpack_reduce_ops.cpp (+42)

@@ -204,3 +204,45 @@ TEST(LoweringPasses, UnpackStdUnbiasedKeepDimsLowersCorrectly) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
 }
+
+TEST(LoweringPasses, UnpackRsqrtLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : Tensor = aten::rsqrt(%x.1)
+      return (%2))IR";
+
+  // Make range [0.01, 1.01] to ensure positives / avoid NaN with negative sqrt
+  auto in = at::rand({2, 3, 5, 7}, {at::kCUDA}) + 0.01;
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::UnpackRsqrt(g);
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
+
+TEST(LoweringPasses, UnpackRsqrtIntLowersCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : Tensor = aten::rsqrt(%x.1)
+      return (%2))IR";
+
+  // Make range of ints [1, 10]
+  auto in = at::randint(1, 11, {2, 3, 5, 7}, {at::kCUDA});
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+  torch_tensorrt::core::lowering::passes::UnpackRsqrt(g);
+  torch::jit::EliminateCommonSubexpression(g);
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}
