Commit 8927e77
feat(//core/partitioning): Improved logging and code org for the segmentation step of partitioning

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent 17e0e8a

12 files changed (+249, -43 lines)

Diff for: core/partitioning/SegmentedBlock.cpp

+20

@@ -4,6 +4,14 @@ namespace trtorch {
 namespace core {
 namespace partitioning {
 
+SegmentedBlock::SegmentedBlock(BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes)
+    : id_(id), target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
+  for (auto& node : nodes) {
+    nodes_.push_back(node);
+    appendNode(node);
+  }
+}
+
 SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes)
     : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
   for (auto& node : nodes) {
@@ -62,6 +70,18 @@ torch::jit::Node* SegmentedBlock::cloneNode(torch::jit::Node* node) {
   return new_node;
 }
 
+std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b) {
+  os << "Segment Block @" << b.id_ << ":" << std::endl;
+  os << " Target: " << b.target_ << std::endl;
+  os << " Graph: " << *b.g_ << std::endl;
+  return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t) {
+  os << SegmentedBlock::target_to_str(t) << std::endl;
+  return os;
+}
+
 } // namespace partitioning
 } // namespace core
 } // namespace trtorch

Diff for: core/partitioning/SegmentedBlock.h

+17

@@ -1,6 +1,7 @@
 #pragma once
 
 #include <vector>
+#include <ostream>
 
 #include "NvInfer.h"
 #include "core/ir/ir.h"
@@ -18,10 +19,21 @@ struct SegmentedBlock {
     kTensorRT,
   };
 
+  static std::string target_to_str(SegmentedBlockTarget t) {
+    if (t == SegmentedBlockTarget::kTorch) {
+      return "Torch";
+    } else {
+      return "TensorRT";
+    }
+  }
+
+  using BlockID = uint64_t;
+
   SegmentedBlock() = default;
   SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
   SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
   SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
+  SegmentedBlock(BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
 
   torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);
   torch::jit::Node* cloneNode(torch::jit::Node* node);
@@ -74,7 +86,10 @@ struct SegmentedBlock {
     return target_;
   }
 
+  friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b);
+
  private:
+  BlockID id_;
   SegmentedBlockTarget target_;
   std::vector<ir::Input> in_shape_;
   std::vector<torch::jit::Value*> inputs_;
@@ -84,6 +99,8 @@ struct SegmentedBlock {
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
 };
 
+std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);
+
 } // namespace partitioning
 } // namespace core
 } // namespace trtorch
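For context, a minimal sketch of how the new BlockID-aware constructor and the operator<< overloads could be exercised; the helper function and the input graph below are hypothetical and not part of this commit.

#include <iostream>
#include <memory>
#include <vector>

#include "core/partitioning/SegmentedBlock.h"

namespace partitioning = trtorch::core::partitioning;

// Hypothetical helper: wrap all nodes of a JIT graph in a single TensorRT-targeted
// block and print it with the stream operators added by this commit.
void print_single_block(const std::shared_ptr<torch::jit::Graph>& g) {
  std::vector<torch::jit::Node*> nodes;
  for (auto n : g->block()->nodes()) {
    nodes.push_back(n);
  }

  // New constructor taking an explicit BlockID (0 here)
  partitioning::SegmentedBlock blk(0, partitioning::SegmentedBlock::kTensorRT, nodes);

  // Prints "Segment Block @0:", its target ("TensorRT") and the cloned graph
  std::cout << blk << std::endl;
}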

Diff for: core/partitioning/partitioning.cpp

+84 -34

@@ -275,81 +275,120 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
   return compile_to_trt;
 }
 
-std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
+bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set<std::string>& torch_ops) {
+  // If the op is not supported by the conversion phase it should run in PyTorch
+  if (!conversion::OpSupported(n)) {
+    LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
+    return false;
+  }
+
+  // If the user specifies the op to run in Torch it should run in PyTorch
+  if (torch_ops.find(n->kind().toQualString()) != torch_ops.end()) {
+    LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
+    return false;
+  }
+
+  // If the user specifies the module containing this op to run in torch it should run in PyTorch
+  const auto to_compile_sym = c10::Symbol::attr("to_compile");
+  if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
+    LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+    return false;
+  }
+
+  LOG_GRAPH("Node is going to run in TensorRT: " << util::node_info(n));
+  return true;
+}
+
+void finalize_block(PartitionedGraph& g, SegmentedBlock::SegmentedBlockTarget kind, std::vector<torch::jit::Node*>& nodes) {
+  SegmentedBlock::BlockID b_id= g.size();
+  LOG_DEBUG("Finalizing in progress " << SegmentedBlock::target_to_str(kind) << " block");
+  g.emplace_back(b_id, kind, nodes);
+  nodes.clear();
+  LOG_DEBUG(g.back());
+}
+
+PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
   auto min_block_size = partition_info.min_block_size;
-  std::unordered_set<std::string> forced_fallback_operators(
+  std::unordered_set<std::string> forced_fallback_ops(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
 
   auto nodes = block->nodes();
-  std::vector<SegmentedBlock> segmented_blocks;
+  PartitionedGraph segmented_blocks;
 
   // segment the nodes
-  std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
+  std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
   for (const auto n : nodes) {
+    // Skip constant nodes as they are resources for both kinds of modules
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
     }
 
-    std::string node_string(n->kind().toQualString());
-    auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
-    if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string) &&
-        (!has_compile_attribute || n->i(c10::Symbol::attr("to_compile")) == (int64_t) true)) {
-      tensorrt_nodes.push_back(n);
-      if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
-        segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-        pytorch_nodes.clear();
+    if (should_run_in_trt(n, forced_fallback_ops)) {
+      in_prog_trt_blk_nodes.push_back(n);
+
+      // If there is an active PyTorch block and we have passed the threshold for a valid TRT
+      // block then segment and reset the active PyTorch block
+      if (in_prog_trt_blk_nodes.size() >= min_block_size && !in_prog_pyt_blk_nodes.empty()) {
+        finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
       }
     } else {
-      if (tensorrt_nodes.size() >= min_block_size) {
-        segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+      // If there is an active TRT block that is valid segment and reset the active TRT block
+      // otherwise add it to the active PyTorch block and reset
+      if (in_prog_trt_blk_nodes.size() >= min_block_size) {
+        finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
      } else {
-        pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
+        LOG_DEBUG("In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
+        in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
      }
-      tensorrt_nodes.clear();
+      in_prog_trt_blk_nodes.clear();
      // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
      // we shouldn't inject node for this block in dependency analysis process
      if (n->kind() == torch::jit::prim::If) {
-        if (!pytorch_nodes.empty()) {
-          segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-          pytorch_nodes.clear();
+        LOG_DEBUG("Hit a conditional statement, finializing in progress PYT block and creating a new one for the conditional");
+        if (!in_prog_pyt_blk_nodes.empty()) {
+          finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
        }
-        segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
+        auto cond_node = std::vector<torch::jit::Node*>{n};
+        finalize_block(segmented_blocks, SegmentedBlock::kTorch, cond_node);
        continue;
      } else if (n->kind() == torch::jit::prim::Loop) {
-        if (!pytorch_nodes.empty()) {
-          segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-          pytorch_nodes.clear();
+        if (!in_prog_pyt_blk_nodes.empty()) {
+          finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
        }
        if (checkLoopEvaluatable(n)) {
-          tensorrt_nodes.push_back(n);
+          in_prog_trt_blk_nodes.push_back(n);
        } else {
-          segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
+          auto loop_node = std::vector<torch::jit::Node*>{n};
+          finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node);
        }
        continue;
      }
-      pytorch_nodes.push_back(n);
+      in_prog_pyt_blk_nodes.push_back(n);
    }
  }
 
   // if there is any kTorch nodes left, then either the last nodes are kTorch or last nodes are kTensorRT but num <
   // min_block_size
-  if (!pytorch_nodes.empty()) {
-    pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
-    segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
-  } else {
-    segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+  if (in_prog_trt_blk_nodes.size() >= min_block_size) {
+    finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
+  }
+
+  if (!in_prog_pyt_blk_nodes.empty()) {
+    in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+    finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
 
   return std::move(segmented_blocks);
 }
 
-std::vector<SegmentedBlock> Partition(
+PartitionedGraph Partition(
     torch::jit::Block* block,
     std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
     const PartitionInfo& partition_info) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
-  std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);
+  LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
+  PartitionedGraph segmented_blocks = segment_graph(block, partition_info);
 
   // resolve nonTensor inputs/outputs
   resolveNonTensorInputs(segmented_blocks);
@@ -358,11 +397,22 @@ std::vector<SegmentedBlock> Partition(
   registerSegmentsOutputs(segmented_blocks, block);
 
   // run shape analysis on each segmented block
-  runShapeAnalysis(segmented_blocks, input_ivalues_map);
+  runShapeAnalysis(segmented_blocks, input_ivalues_map, at::kFloat);
+
+  LOG_INFO(segmented_blocks);
 
   return segmented_blocks;
 }
 
+std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g) {
+  os << "Partitioned Graph: [";
+  for (auto b : g) {
+    os << b;
+  }
+  os << "]";
+  return os;
+}
+
 } // namespace partitioning
 } // namespace core
 } // namespace trtorch
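As a rough sketch of how the reorganized entry point might be driven, the snippet below segments a lowered block and dumps the result through the new PartitionedGraph stream operator; the driver function and the example fallback op are assumptions for illustration, not code from this commit.

#include <iostream>

#include "core/partitioning/partitioning.h"

namespace partitioning = trtorch::core::partitioning;

// Hypothetical driver around segment_graph, which now returns a PartitionedGraph
// (std::vector<SegmentedBlock>) with per-block ids assigned by finalize_block.
void segment_and_dump(torch::jit::Block* block) {
  partitioning::PartitionInfo info;
  info.min_block_size = 3;  // TRT runs shorter than 3 nodes get folded into PyTorch blocks
  info.forced_fallback_operators = {"aten::upsample_nearest2d"};  // example op forced to run in Torch

  partitioning::PartitionedGraph blocks = partitioning::segment_graph(block, info);

  // Prints "Partitioned Graph: [Segment Block @0: ..." via the operator<< added here
  std::cout << blocks << std::endl;
}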

Diff for: core/partitioning/partitioning.h

+4 -1

@@ -1,6 +1,7 @@
 #pragma once
 
 #include <vector>
+#include <iostream>
 
 #include "core/ir/ir.h"
 #include "core/partitioning/PartitionInfo.h"
@@ -17,11 +18,13 @@ typedef std::vector<SegmentedBlock> PartitionedGraph;
 
 PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info);
 
-std::vector<SegmentedBlock> Partition(
+PartitionedGraph Partition(
     torch::jit::Block* block,
     std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
     const PartitionInfo& partition_info);
 
+std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g);
+
 } // namespace partitioning
 } // namespace core
 } // namespace trtorch

Diff for: core/util/BUILD

+3

@@ -27,6 +27,9 @@ cc_library(
     hdrs = [
         "jit_util.h",
     ],
+    srcs = [
+        "jit_util.cpp"
+    ],
     deps = select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

Diff for: core/util/jit_util.cpp

+69

@@ -0,0 +1,69 @@
+#include "core/util/jit_util.h"
+
+namespace trtorch {
+namespace core {
+namespace util {
+
+c10::optional<at::ScalarType> getBlockFirstCalcDType(const std::shared_ptr<torch::jit::Block>& b) {
+  auto ns = b->nodes();
+
+  c10::optional<at::ScalarType> dtype = {};
+
+  // For each node check the inputs to find a prim:Constant, which will provide a static tensor.
+  // Use that tensor to determine operating dtype for the first calculation in the block
+  for (auto n : ns) {
+    if (n->kind() == torch::jit::prim::Constant) {
+      // Not really helpful to evaluate typing for constants
+      continue;
+    }
+
+    auto ins = n->inputs();
+    auto outs = n->outputs();
+
+    bool outputs_tensor = false;
+    for (auto o : outs) {
+      if (o->type() == c10::TensorType::get()) {
+        outputs_tensor = true;
+      }
+    }
+
+    if (outputs_tensor) {
+      // If all input tensors are block inputs then this node will not give us useful type info so move to the next one
+      std::unordered_set<torch::jit::Value*> node_input_set = {ins.begin(), ins.end()};
+
+      bool all_n_ins_are_b_ins = true;
+      for (auto b_in : b->inputs()) {
+        if (node_input_set.find(b_in) == node_input_set.end()) {
+          all_n_ins_are_b_ins = false;
+        }
+      }
+
+      if (all_n_ins_are_b_ins) {
+        continue;
+      }
+
+
+      // If node outputs a Tensor it might be a result of tensor calcuation so check to see
+      // if any inputs to the calculation can give us hints
+      c10::optional<torch::jit::Node*> const_tensor_n = {};
+
+      // Backtrace to constants which will immediately give us the Tensor type if possible
+      for (auto in : ins) {
+        if (in->type() == c10::TensorType::get()) {
+          if (in->node()->kind() == torch::jit::prim::Constant) {
+            auto const_ival = in->node()->get(c10::Symbol::attr("value"));
+            dtype = {const_ival.value().toTensor().scalar_type()};
+            goto exit_first_calc_dtype;
+          }
+        }
+      }
+    }
+  }
+
+exit_first_calc_dtype:
+  return dtype;
+}
+
+} // namespace util
+} // namespace core
+} // namespace trtorch
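An illustrative sketch of how getBlockFirstCalcDType might be consumed; the wrapper below and its fallback behavior are assumptions rather than code from the commit, and the caller is assumed to already hold the block in a std::shared_ptr as required by the new signature.

#include "core/util/jit_util.h"

namespace util = trtorch::core::util;

// Hypothetical wrapper: return the first constant-derived tensor dtype found in the
// block, or default to float (mirroring the at::kFloat now passed to runShapeAnalysis).
at::ScalarType first_calc_dtype_or_default(const std::shared_ptr<torch::jit::Block>& blk) {
  c10::optional<at::ScalarType> dtype = util::getBlockFirstCalcDType(blk);
  if (dtype) {
    return dtype.value();
  }
  return at::kFloat;
}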

Diff for: core/util/jit_util.h

+2

@@ -52,6 +52,8 @@ inline std::string GetPyTorchSourceCode(const torch::jit::Node* n) {
   return source_code;
 }
 
+c10::optional<at::ScalarType> getBlockFirstCalcDType(const std::shared_ptr<torch::jit::Block>& b);
+
 } // namespace util
 } // namespace core
 } // namespace trtorch

Diff for: core/util/logging/TRTorchLogger.cpp

+1 -1

@@ -125,7 +125,7 @@ namespace {
 
 TRTorchLogger& get_global_logger() {
 #ifndef NDEBUG
-  static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kDEBUG, true);
+  static TRTorchLogger global_logger("[TRTorch - Debug Build] - ", LogLevel::kGRAPH, true);
 #else
   static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false);
 #endif

Diff for: tests/core/partitioning/BUILD

+1 -1

@@ -81,7 +81,7 @@ cc_test(
 )
 
 test_suite(
-    name = "partitioning_test",
+    name = "partitioning_tests",
     tests = [
         ":test_segmentation",
         ":test_shape_analysis",
