Skip to content

Commit 3cc2dfb

Browse files
committed
fix: refactor the resegmentation for TensorRT segments in ResolveNonTensorInput
Signed-off-by: Bo Wang <[email protected]>
1 parent 10b55d4 commit 3cc2dfb

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

Diff for: core/partitioning/partitioning.cpp

File mode changed: 100644 → 100755
+51-15
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
6868
q.pop();
6969
auto node = cur_val->node();
7070
if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
71+
visited.insert(node);
7172
stk.push_back(node);
7273
for (auto input : node->inputs()) {
7374
if (!isTensorOrTensorList(input)) {
@@ -89,14 +90,14 @@ std::vector<torch::jit::Node*> getOutputNodes(
8990
std::unordered_set<torch::jit::Node*> visited;
9091
q.push(value);
9192

92-
// top-down order traveling
93+
// top-down order traversing
9394
while (!q.empty()) {
9495
auto cur_val = q.front();
9596
q.pop();
9697
for (auto use : cur_val->uses()) {
9798
auto node = use.user;
9899
// use node must be in seg_block_nodes
99-
if (seg_block_nodes.count(node) != 0 && !visited.count(node)) {
100+
if (seg_block_nodes.count(node) && !visited.count(node)) {
100101
stk.push_back(node);
101102
visited.insert(node);
102103
// travel its' all outputs
@@ -109,10 +110,41 @@ std::vector<torch::jit::Node*> getOutputNodes(
109110
}
110111
}
111112

112-
// top-down order and we don't need reverse it
113+
// top-down order and we don't need to reverse it
113114
return stk;
114115
}
115116

117+
void getDirtyNodes(
118+
std::unordered_set<torch::jit::Node*>& dirty_nodes,
119+
const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
120+
std::queue<torch::jit::Node*> q;
121+
for (auto& node : dirty_nodes) {
122+
q.push(node);
123+
}
124+
dirty_nodes.clear();
125+
126+
while (!q.empty()) {
127+
auto cur_node = q.front();
128+
q.pop();
129+
if (!dirty_nodes.count(cur_node) && seg_block_nodes.count(cur_node)) {
130+
dirty_nodes.insert(cur_node);
131+
for (auto input : cur_node->inputs()) {
132+
if (!isTensorOrTensorList(input)) {
133+
q.push(input->node());
134+
}
135+
}
136+
for (auto output : cur_node->outputs()) {
137+
if (!isTensorOrTensorList(output)) {
138+
for (auto use : output->uses()) {
139+
auto node = use.user;
140+
q.push(node);
141+
}
142+
}
143+
}
144+
}
145+
}
146+
}
147+
116148
std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock> segmentBlocksWithTensorListInputs(
117149
SegmentedBlock& seg_block,
118150
const std::unordered_map<torch::jit::Value*, SegmentedBlock>& tensorlist_inputs) {
@@ -163,25 +195,29 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
163195
} else {
164196
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
165197
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
166-
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end());
198+
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
167199

168-
bool prev_non_tensor_outputs = false;
200+
// take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in PyTorch block), then we use
201+
// dfs/bfs to find all dirty nodes that consume non_tensor values produced by dirty nodes or produces non_tensor
202+
// values consumed by dirty nodes
203+
std::unordered_set<torch::jit::Node*> dirty_nodes;
204+
const std::unordered_set<torch::jit::Node*> seg_block_nodes(
205+
seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
206+
207+
for (auto n : seg_block.raw_nodes()) {
208+
if (containTargetInputs(n, nontensor_inputs_set)) {
209+
dirty_nodes.insert(n);
210+
}
211+
}
212+
getDirtyNodes(dirty_nodes, seg_block_nodes);
169213
for (auto n : seg_block.raw_nodes()) {
170-
// Check if the node has non-tensor inputs or if it consumes non-tensor outputs of previous node.
171-
// In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT
172-
// SegmentedBlock.
173-
if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
174-
// If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
175-
// TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments.
214+
if (dirty_nodes.count(n)) {
176215
if (!tensorrt_nodes.empty()) {
177216
new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes);
178217
tensorrt_nodes.clear();
179218
}
180219
pytorch_nodes.push_back(n);
181-
prev_non_tensor_outputs = containNonTensorOutputs(n);
182220
} else {
183-
// If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
184-
// Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
185221
if (!pytorch_nodes.empty()) {
186222
new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTorch, pytorch_nodes);
187223
pytorch_nodes.clear();
@@ -190,7 +226,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
190226
}
191227
}
192228

193-
// Form the last segmented_block with the left over nodes in tensorrt_nodes or pytorch_nodes correspondingly.
229+
// Form the last segmented_block with the leftover nodes in tensorrt_nodes or pytorch_nodes correspondingly.
194230
if (!tensorrt_nodes.empty()) {
195231
new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes);
196232
} else {

Diff for: core/partitioning/shape_analysis.cpp

+2
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,8 @@ void getSegmentsOutputByRunning(
8686
jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar());
8787
} else if (input->type()->kind() == torch::jit::TypeKind::DictType) {
8888
jit_inputs_ivalues.push_back(ivalues_maps[input].toGenericDict());
89+
} else if (input->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
90+
jit_inputs_ivalues.push_back(ivalues_maps[input].toDevice());
8991
} else {
9092
TORCHTRT_THROW_ERROR(
9193
"Expected to find type " << input->type()->str() << " for value " << input->debugName()

0 commit comments

Comments (0)