diff --git a/core/lowering/passes/module_fallback.cpp b/core/lowering/passes/module_fallback.cpp index 415b385634..b40eaaf235 100644 --- a/core/lowering/passes/module_fallback.cpp +++ b/core/lowering/passes/module_fallback.cpp @@ -19,9 +19,10 @@ std::string unmangle_cls_name(const std::string& name) { std::size_t mangle_pos = unmangled.find("___torch_mangle_"); if (mangle_pos != std::string::npos) { - unmangled.erase(mangle_pos, 21); + std::size_t dot_pos = unmangled.find(".", mangle_pos); + TORCH_CHECK(dot_pos != std::string::npos, "Expected to find '.' after '___torch_mangle_' in name: ", unmangled); + unmangled.erase(mangle_pos, dot_pos - mangle_pos + 1); } - return unmangled; } diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index 713894e0c7..557d682022 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -46,6 +46,9 @@ void UnpackAndCastNumToTensor(std::shared_ptr& graph, std::st void UnpackAndCastFull(std::shared_ptr& graph, std::string target_device_name); void ReplaceScalarImplicit(std::shared_ptr& graph); +// utility functions exposed for testing +std::string unmangle_cls_name(const std::string& name); + } // namespace passes } // namespace lowering } // namespace core diff --git a/tests/core/lowering/test_module_fallback_passes.cpp b/tests/core/lowering/test_module_fallback_passes.cpp index 5f4ac5f0c2..e94cb7807a 100644 --- a/tests/core/lowering/test_module_fallback_passes.cpp +++ b/tests/core/lowering/test_module_fallback_passes.cpp @@ -126,3 +126,17 @@ TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) { auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor(); ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results, 0.99)); } + +TEST(Lowering, UnmangleClsName) { + EXPECT_EQ( + "foo.Bar", torch_tensorrt::core::lowering::passes::unmangle_cls_name("__torch__.foo.___torch_mangle_605.Bar")); + EXPECT_EQ( + "torch.nn.modules.conv.Conv2d", + torch_tensorrt::core::lowering::passes::unmangle_cls_name( + "__torch__.torch.nn.modules.conv.___torch_mangle_5697.Conv2d")); + EXPECT_EQ( + "custom_models.ModuleFallbackMain", + torch_tensorrt::core::lowering::passes::unmangle_cls_name("__torch__.custom_models.ModuleFallbackMain")); + EXPECT_THROW( + torch_tensorrt::core::lowering::passes::unmangle_cls_name("__torch__.foo.___torch_mangle_605"), std::exception); +}