
Commit c67d8f6

feat: support the case when the injected node is not supported in dependency analysis
Signed-off-by: Bo Wang <[email protected]>
1 parent: de3ba23

2 files changed (+75 -14 lines)

Diff for: core/compiler.cpp (+1 -1)
@@ -274,7 +274,7 @@ torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
   auto new_g = std::make_shared<torch::jit::Graph>();
   AddEngineToGraph(new_mod, new_g, engine);
   auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
-  auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
+  auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
   new_mod.type()->addMethod(new_method);
   new_method->setSchema(schema);
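
For context, a minimal usage sketch of the function touched above — not part of this commit, and the include path, engine source, and file names are assumptions:

#include <fstream>
#include <sstream>
#include <string>
#include "core/compiler.h"

int main() {
  // Read pre-serialized TensorRT engine bytes from disk (hypothetical path).
  std::ifstream in("engine.trt", std::ios::binary);
  std::stringstream buf;
  buf << in.rdbuf();

  // EmbedEngineInNewModule wraps the engine in a fresh TorchScript module;
  // after this commit, its "forward" schema comes from util::GenerateGraphSchema.
  auto mod = trtorch::core::EmbedEngineInNewModule(buf.str());
  mod.save("embedded.ts");  // hypothetical output path
  return 0;
}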

Diff for: core/partitioning/partitioning.cpp (+74 -13)
@@ -9,19 +9,46 @@ namespace trtorch {
 namespace core {
 namespace partitioning {
 
-inline bool isTensorOrTensorList(torch::jit::Value* val) {
-  return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
-      val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
-}
-
 struct usage_info {
   int produce_id = -1;
   std::vector<int> torch_use_id;
   std::vector<int> tensorrt_use_id;
 };
 
+inline bool isTensorOrTensorList(torch::jit::Value* val) {
+  return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
+      val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
+}
+
+bool isAllNodesSupported(const std::vector<torch::jit::Node*>& nodes) {
+  for (auto node : nodes) {
+    if (!conversion::OpSupported(node)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool containNonTensorInputs(torch::jit::Node* n, const std::unordered_set<torch::jit::Value*>& target_inputs) {
+  for (auto input : n->inputs()) {
+    if (!isTensorOrTensorList(input) && target_inputs.count(input)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool containNonTensorOutputs(torch::jit::Node* n) {
+  for (auto output : n->outputs()) {
+    if (!isTensorOrTensorList(output)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
-  // using bfs to get the DAG dependency nodes for input value
+  // use bfs to get the DAG dependency nodes for input value
   std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
       std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
   std::unordered_set<torch::jit::Node*> visited;
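
To see what the new predicates classify, here is a hedged, self-contained sketch (not from this commit) that inlines the same isTensorOrTensorList test on a toy graph; the IR string and op choices are assumptions, while torch::jit::parseIR is the standard PyTorch IR parser:

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <iostream>
#include <memory>
#include <string>

int main() {
  // Toy graph: aten::relu produces a Tensor, aten::dim produces an int.
  const std::string src = R"IR(
graph(%x : Tensor):
  %y : Tensor = aten::relu(%x)
  %d : int = aten::dim(%y)
  return (%d))IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());

  for (auto n : g->nodes()) {
    bool has_nontensor_output = false;
    for (auto out : n->outputs()) {
      // Same check as isTensorOrTensorList above, negated per output.
      if (!(out->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
            out->type()->isSubtypeOf(torch::jit::ListType::ofTensors()))) {
        has_nontensor_output = true;
      }
    }
    // containNonTensorOutputs would return true only for the aten::dim node.
    std::cout << n->kind().toQualString() << " -> " << has_nontensor_output << "\n";
  }
  return 0;
}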
@@ -43,17 +70,50 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
   return stk;
 }
 
-SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
+std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
   // reconstruct segmented_block if this block requires nonTensor input
   std::vector<torch::jit::Value*> nontensor_inputs;
   for (auto input : seg_block.raw_inputs()) {
     if (!isTensorOrTensorList(input)) {
       nontensor_inputs.push_back(input);
     }
   }
-  std::vector<torch::jit::Node*> new_block_nodes = getDependencyNodes(nontensor_inputs);
-  new_block_nodes.insert(new_block_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
-  return std::move(SegmentedBlock(seg_block.target(), new_block_nodes));
+  std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
+
+  std::vector<SegmentedBlock> new_seg_blocks;
+  // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, construct only
+  // one new block
+  if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
+    dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
+    new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
+  } else {
+    // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
+    std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
+    new_seg_blocks.emplace_back(SegmentedBlock::kTorch, dependency_nodes);
+    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
+    bool prev_non_tensor_outputs = false;
+    for (auto n : seg_block.raw_nodes()) {
+      // it's a kTorch block if it uses the nonTensor input and the nonTensor input is produced in kTorch block
+      if (containNonTensorInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
+        if (!tensorrt_nodes.empty()) {
+          new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+        }
+        pytorch_nodes.push_back(n);
+        prev_non_tensor_outputs = containNonTensorOutputs(n);
+      } else {
+        if (!pytorch_nodes.empty()) {
+          new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+        }
+        tensorrt_nodes.push_back(n);
+      }
+    }
+    if (!tensorrt_nodes.empty()) {
+      new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
+    } else {
+      new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
+    }
+  }
+  return std::move(new_seg_blocks);
 }
 
 void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
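
The alternating flush pattern in the loop above can be hard to read in diff form. The following self-contained toy sketch (an illustration, not TRTorch code) applies the same pattern to a plain op list, with a boolean standing in for containNonTensorInputs || prev_non_tensor_outputs; unlike the diff, this toy version clears each run after flushing:

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Each toy op is forced to Torch when it consumes a nonTensor value
  // produced in a Torch block, mirroring the condition in the loop above.
  struct Op { std::string name; bool must_run_in_torch; };
  std::vector<Op> ops = {
      {"conv", false}, {"size", true}, {"view", false}, {"relu", false}};

  struct Block { char target; std::vector<std::string> nodes; };  // 'R' = TensorRT, 'T' = Torch
  std::vector<Block> blocks;
  std::vector<std::string> trt_nodes, torch_nodes;

  for (const auto& op : ops) {
    if (op.must_run_in_torch) {
      if (!trt_nodes.empty()) {  // flush the pending TensorRT run
        blocks.push_back({'R', trt_nodes});
        trt_nodes.clear();
      }
      torch_nodes.push_back(op.name);
    } else {
      if (!torch_nodes.empty()) {  // flush the pending Torch run
        blocks.push_back({'T', torch_nodes});
        torch_nodes.clear();
      }
      trt_nodes.push_back(op.name);
    }
  }
  if (!trt_nodes.empty()) {
    blocks.push_back({'R', trt_nodes});
  } else if (!torch_nodes.empty()) {
    blocks.push_back({'T', torch_nodes});
  }

  // Prints: R[conv] T[size] R[view relu]
  for (const auto& b : blocks) {
    std::cout << b.target << "[";
    for (size_t i = 0; i < b.nodes.size(); ++i) {
      std::cout << (i ? " " : "") << b.nodes[i];
    }
    std::cout << "] ";
  }
  std::cout << "\n";
  return 0;
}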
@@ -80,16 +140,17 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
     if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
       int first_torch_id = use_info.torch_use_id.front();
       if (!updated_segments.count(first_torch_id)) {
-        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
+        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]).front();
         segmented_blocks[first_torch_id] = new_torch_block;
         updated_segments.insert(first_torch_id);
       }
     } else {
       // KTensorRT segments always need to inject nodes for the nonTensor inputs
       for (int i : use_info.tensorrt_use_id) {
         if (!updated_segments.count(i)) {
-          auto new_seg_block = injectNodesForNonTensorInputs(segmented_blocks[i]);
-          segmented_blocks[i] = new_seg_block;
+          auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[i]);
+          segmented_blocks.erase(segmented_blocks.begin() + i);
+          segmented_blocks.insert(segmented_blocks.begin() + i, to_inject_blocks.begin(), to_inject_blocks.end());
           updated_segments.insert(i);
         }
       }
     }
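
The erase-then-insert pair above splices a variable-length run of re-segmented blocks in place of the single original segment. A minimal self-contained sketch of the same idiom on integers, purely illustrative:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> segments = {10, 20, 30};
  std::vector<int> to_inject = {21, 22, 23};  // stands in for to_inject_blocks
  int i = 1;

  // Remove the old segment, then splice the replacement run in at its index.
  segments.erase(segments.begin() + i);
  segments.insert(segments.begin() + i, to_inject.begin(), to_inject.end());

  assert((segments == std::vector<int>{10, 21, 22, 23, 30}));
  return 0;
}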
