Skip to content

Commit 2ee2a84

Browse files
committed
feat: show pytorch code of unsupported operators
Signed-off-by: lamhoangtung <[email protected]>
1 parent bdaacf1 commit 2ee2a84

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

Diff for: core/conversion/conversion.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,8 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
428428
return engine;
429429
}
430430

431-
std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
432-
std::set<std::string> unsupported_ops;
431+
std::set<std::pair<std::string, std::string>> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
432+
std::set<std::pair<std::string, std::string>> unsupported_ops;
433433
for (const auto n : b->nodes()) {
434434
if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
435435
auto schema = n->maybeSchema();
@@ -438,7 +438,9 @@ std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
438438
"Unable to get schema for Node " << util::node_info(n) << " (conversion.VerifyCoverterSupportForBlock)");
439439
std::stringstream ss;
440440
ss << *schema;
441-
unsupported_ops.insert(ss.str());
441+
std::string pytorch_code = trtorch::core::util::GetPyTorchSourceCode(n);
442+
auto current_unsupported_op = std::make_pair(ss.str(), pytorch_code);
443+
unsupported_ops.insert(current_unsupported_op);
442444
}
443445
for (const auto sub_b : n->blocks()) {
444446
auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock(sub_b);
@@ -480,7 +482,8 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
480482
unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:"
481483
<< std::endl;
482484
for (auto s : unsupported_ops) {
483-
unsupported_msg << " - " << s << std::endl;
485+
unsupported_msg << " - " << s.first << std::endl;
486+
unsupported_msg << " Related PyTorch code:" << std::endl << s.second << std::endl;
484487
}
485488
unsupported_msg << "You can either implement converters for these ops in your application or request implementation"
486489
<< std::endl;

Diff for: core/util/jit_util.h

+5
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ inline c10::FunctionSchema GenerateGraphSchema(std::string method_name, std::sha
4747
return c10::FunctionSchema(method_name, method_name, args, returns);
4848
}
4949

50+
inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
51+
std::string source_code = n->sourceRange().str();
52+
return source_code;
53+
}
54+
5055
} // namespace util
5156
} // namespace core
5257
} // namespace trtorch

0 commit comments

Comments
 (0)