Skip to content

Commit 1aa492f

Browse files
committed
chore: Debugging commit
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 01c6952 commit 1aa492f

File tree

5 files changed

+38
-32
lines changed

5 files changed

+38
-32
lines changed

Diff for: core/lowering/passes/module_fallback.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void NotateModuleForFallback(
3939
if (n->kind() == torch::jit::prim::GetAttr) {
4040
auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type()));
4141
if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) {
42-
LOG_DEBUG(
42+
LOG_GRAPH(
4343
"Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name
4444
<< " (" << cls_name << ")]");
4545
auto uses = n->output(0)->uses();
@@ -58,7 +58,7 @@ void NotateModuleForFallback(
5858
}
5959

6060
if (changed_mod) {
61-
LOG_DEBUG("Notated graph: " << *g);
61+
LOG_GRAPH("Notated graph: " << *g);
6262
}
6363

6464
for (const auto sub_mod : mod.named_children()) {
@@ -106,10 +106,10 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
106106
}
107107
}
108108

109-
LOG_DEBUG("After marking operations for torch fallback: " << *g);
109+
LOG_GRAPH("After marking operations for torch fallback: " << *g);
110110
}
111111

112112
} // namespace passes
113113
} // namespace lowering
114114
} // namespace core
115-
} // namespace trtorch
115+
} // namespace trtorch

Diff for: core/lowering/passes/unpack_var.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
4242
torch::jit::SubgraphRewriter var_rewriter;
4343
var_rewriter.RegisterRewritePattern(var_pattern, unpacked_pattern);
4444
var_rewriter.runOnGraph(graph);
45-
LOG_DEBUG("Post unpack var: " << *graph);
45+
LOG_GRAPH("Post unpack var: " << *graph);
4646
}
4747

4848
} // namespace passes

Diff for: core/partitioning/partitioning.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -347,15 +347,19 @@ std::vector<SegmentedBlock> Partition(
347347
const PartitionInfo& partition_info) {
348348
LOG_DEBUG(partition_info);
349349
// segment lowering global graph into blocks
350+
LOG_DEBUG("Partitioning graph into PyTorch and TensorRT segmented blocks");
350351
std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);
351352

352353
// resolve nonTensor inputs/outputs
354+
LOG_DEBUG("Resolving non-tensor type inputs/outputs (eg: int/float types)");
353355
resolveNonTensorInputs(segmented_blocks);
354356

355357
// register input/output torch::jit::Value for segmented graphs
358+
LOG_DEBUG("Registering input/outputs for segmented blocks");
356359
registerSegmentsOutputs(segmented_blocks, block);
357360

358361
// run shape analysis on each segmented block
362+
LOG_DEBUG("Running shape analysis for all the segmented blocks");
359363
runShapeAnalysis(segmented_blocks, input_ivalues_map);
360364

361365
return segmented_blocks;

Diff for: core/partitioning/shape_analysis.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ void runShapeAnalysis(
108108
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
109109
// register every segment's input shape, and its running output IValues
110110
for (auto& seg_block : segmented_blocks) {
111+
LOG_DEBUG("Segmented graph: " << *seg_block.g());
111112
torch::jit::ConstantPooling(seg_block.g());
112113
getSegmentsOutputByRunning(seg_block, ivalues_maps);
114+
LOG_DEBUG("=================");
113115
}
114116
return;
115117
}

Diff for: tests/core/partitioning/test_loop_fallback.cpp

+27-27
Original file line numberDiff line numberDiff line change
@@ -33,30 +33,30 @@ TEST(Partitioning, CheckLoopFallbackEvalCompilesCorrectly) {
3333
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
3434
}
3535

36-
TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
37-
torch::jit::script::Module mod;
38-
try {
39-
mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt");
40-
} catch (const c10::Error& e) {
41-
std::cerr << "error loading the model\n";
42-
return;
43-
}
44-
45-
const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}};
46-
std::vector<torch::jit::IValue> jit_inputs_ivalues;
47-
std::vector<torch::jit::IValue> trt_inputs_ivalues;
48-
for (auto in_shape : input_shapes) {
49-
auto in = at::randint(5, in_shape, {at::kCUDA});
50-
jit_inputs_ivalues.push_back(in.clone());
51-
trt_inputs_ivalues.push_back(in.clone());
52-
}
53-
54-
std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
55-
trtorch::core::CompileSpec cfg(input_ranges);
56-
cfg.partition_info.enabled = true;
57-
58-
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
59-
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
60-
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
61-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
62-
}
36+
// TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
37+
// torch::jit::script::Module mod;
38+
// try {
39+
// mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt");
40+
// } catch (const c10::Error& e) {
41+
// std::cerr << "error loading the model\n";
42+
// return;
43+
// }
44+
//
45+
// const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}};
46+
// std::vector<torch::jit::IValue> jit_inputs_ivalues;
47+
// std::vector<torch::jit::IValue> trt_inputs_ivalues;
48+
// for (auto in_shape : input_shapes) {
49+
// auto in = at::randint(5, in_shape, {at::kCUDA});
50+
// jit_inputs_ivalues.push_back(in.clone());
51+
// trt_inputs_ivalues.push_back(in.clone());
52+
// }
53+
//
54+
// std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
55+
// trtorch::core::CompileSpec cfg(input_ranges);
56+
// cfg.partition_info.enabled = true;
57+
//
58+
// auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
59+
// auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
60+
// auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
61+
// ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
62+
// }

0 commit comments

Comments
 (0)