Skip to content

Commit ad07645

Browse files
committed
feat(//core/lowering): additional logging in module fallback
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent b96087b commit ad07645

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

Diff for: core/lowering/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ cc_library(
1313
"drop_unused_nodes.cpp",
1414
"lowering.cpp",
1515
"register_trt_placeholder_ops.cpp",
16+
"LowerInfo.cpp"
1617
],
1718
hdrs = [
1819
"lowering.h",

Diff for: core/lowering/LowerInfo.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include <iostream>
2+
#include <sstream>
3+
#include <utility>
4+
5+
#include "core/lowering/lowering.h"
6+
7+
namespace trtorch {
8+
namespace core {
9+
namespace lowering {
10+
11+
std::ostream& operator<<(std::ostream& os, const LowerInfo& l) {
12+
os << "Settings requested for Lowering:" << std::endl;
13+
os << " Forced Fallback Modules: [" << std::endl;
14+
for (auto i : l.forced_fallback_modules) {
15+
os << " " << i << std::endl;
16+
}
17+
os << " ]";
18+
return os;
19+
}
20+
21+
} // namespace lowering
22+
} // namespace core
23+
} // namespace trtorch

Diff for: py/trtorch/_compiler.py

+3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
5959
"force_fallback_ops": [
6060
"aten::max_pool2d" # List of specific ops to require running in PyTorch
6161
],
62+
"force_fallback_modules": [
63+
"mypymod.mytorchmod" # List of specific torch modules to require running in PyTorch
64+
],
6265
"min_block_size": 3 # Minimum number of ops an engine must incapsulate to be run in TensorRT
6366
}
6467
}

0 commit comments

Comments
 (0)