Skip to content

Commit db5d290

Browse files
Fix unmangle_cls_name for variable length mangled tags (#1454)
1 parent a1498ea commit db5d290

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

core/lowering/passes/module_fallback.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ std::string unmangle_cls_name(const std::string& name) {
1919

2020
std::size_t mangle_pos = unmangled.find("___torch_mangle_");
2121
if (mangle_pos != std::string::npos) {
22-
unmangled.erase(mangle_pos, 21);
22+
std::size_t dot_pos = unmangled.find(".", mangle_pos);
23+
TORCH_CHECK(dot_pos != std::string::npos, "Expected to find '.' after '___torch_mangle_' in name: ", unmangled);
24+
unmangled.erase(mangle_pos, dot_pos - mangle_pos + 1);
2325
}
24-
2526
return unmangled;
2627
}
2728

core/lowering/passes/passes.h

+3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::st
4646
void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
4747
void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
4848

49+
// utility functions exposed for testing
50+
std::string unmangle_cls_name(const std::string& name);
51+
4952
} // namespace passes
5053
} // namespace lowering
5154
} // namespace core

tests/core/lowering/test_module_fallback_passes.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,17 @@ TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) {
126126
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
127127
ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99));
128128
}
129+
130+
TEST(Lowering, UnmangleClsName) {
131+
EXPECT_EQ(
132+
"foo.Bar", torch_tensorrt::core::lowering::passes::unmangle_cls_name("__torch__.foo.___torch_mangle_605.Bar"));
133+
EXPECT_EQ(
134+
"torch.nn.modules.conv.Conv2d",
135+
torch_tensorrt::core::lowering::passes::unmangle_cls_name(
136+
"__torch__.torch.nn.modules.conv.___torch_mangle_5697.Conv2d"));
137+
EXPECT_EQ(
138+
"custom_models.ModuleFallbackMain",
139+
torch_tensorrt::core::lowering::passes::unmangle_cls_name("__torch__.custom_models.ModuleFallbackMain"));
140+
EXPECT_THROW(
141+
torch_tensorrt::core::lowering::passes::unmangle_cls_name("__torch__.foo.___torch_mangle_605"), std::exception);
142+
}

0 commit comments

Comments
 (0)