@@ -68,6 +68,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
68
68
q.pop ();
69
69
auto node = cur_val->node ();
70
70
if (node->kind () != torch::jit::prim::Constant && !visited.count (node)) {
71
+ visited.insert (node);
71
72
stk.push_back (node);
72
73
for (auto input : node->inputs ()) {
73
74
if (!isTensorOrTensorList (input)) {
@@ -89,14 +90,14 @@ std::vector<torch::jit::Node*> getOutputNodes(
89
90
std::unordered_set<torch::jit::Node*> visited;
90
91
q.push (value);
91
92
92
- // top-down order traveling
93
+ // top-down order traversing
93
94
while (!q.empty ()) {
94
95
auto cur_val = q.front ();
95
96
q.pop ();
96
97
for (auto use : cur_val->uses ()) {
97
98
auto node = use.user ;
98
99
// use node must be in seg_block_nodes
99
- if (seg_block_nodes.count (node) != 0 && !visited.count (node)) {
100
+ if (seg_block_nodes.count (node) && !visited.count (node)) {
100
101
stk.push_back (node);
101
102
visited.insert (node);
102
103
// travel its' all outputs
@@ -109,10 +110,41 @@ std::vector<torch::jit::Node*> getOutputNodes(
109
110
}
110
111
}
111
112
112
- // top-down order and we don't need reverse it
113
+ // top-down order and we don't need to reverse it
113
114
return stk;
114
115
}
115
116
117
+ void getDirtyNodes (
118
+ std::unordered_set<torch::jit::Node*>& dirty_nodes,
119
+ const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
120
+ std::queue<torch::jit::Node*> q;
121
+ for (auto & node : dirty_nodes) {
122
+ q.push (node);
123
+ }
124
+ dirty_nodes.clear ();
125
+
126
+ while (!q.empty ()) {
127
+ auto cur_node = q.front ();
128
+ q.pop ();
129
+ if (!dirty_nodes.count (cur_node) && seg_block_nodes.count (cur_node)) {
130
+ dirty_nodes.insert (cur_node);
131
+ for (auto input : cur_node->inputs ()) {
132
+ if (!isTensorOrTensorList (input)) {
133
+ q.push (input->node ());
134
+ }
135
+ }
136
+ for (auto output : cur_node->outputs ()) {
137
+ if (!isTensorOrTensorList (output)) {
138
+ for (auto use : output->uses ()) {
139
+ auto node = use.user ;
140
+ q.push (node);
141
+ }
142
+ }
143
+ }
144
+ }
145
+ }
146
+ }
147
+
116
148
std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock> segmentBlocksWithTensorListInputs (
117
149
SegmentedBlock& seg_block,
118
150
const std::unordered_map<torch::jit::Value*, SegmentedBlock>& tensorlist_inputs) {
@@ -163,25 +195,29 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
163
195
} else {
164
196
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
165
197
std::unordered_set<torch::jit::Value*> nontensor_inputs_set (nontensor_inputs.begin (), nontensor_inputs.end ());
166
- std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes (dependency_nodes. begin (), dependency_nodes. end ()) ;
198
+ std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
167
199
168
- bool prev_non_tensor_outputs = false ;
200
+ // take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in PyTorch block), then we use
201
+ // dfs/bfs to find all dirty nodes that consume non_tensor values produced by dirty nodes or produces non_tensor
202
+ // values consumed by dirty nodes
203
+ std::unordered_set<torch::jit::Node*> dirty_nodes;
204
+ const std::unordered_set<torch::jit::Node*> seg_block_nodes (
205
+ seg_block.raw_nodes ().begin (), seg_block.raw_nodes ().end ());
206
+
207
+ for (auto n : seg_block.raw_nodes ()) {
208
+ if (containTargetInputs (n, nontensor_inputs_set)) {
209
+ dirty_nodes.insert (n);
210
+ }
211
+ }
212
+ getDirtyNodes (dirty_nodes, seg_block_nodes);
169
213
for (auto n : seg_block.raw_nodes ()) {
170
- // Check if the node has non-tensor inputs or if it consumes non-tensor outputs of previous node.
171
- // In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT
172
- // SegmentedBlock.
173
- if (containTargetInputs (n, nontensor_inputs_set) || prev_non_tensor_outputs) {
174
- // If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
175
- // TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments.
214
+ if (dirty_nodes.count (n)) {
176
215
if (!tensorrt_nodes.empty ()) {
177
216
new_seg_blocks.emplace_back (new_seg_blocks.size (), SegmentedBlock::kTensorRT , tensorrt_nodes);
178
217
tensorrt_nodes.clear ();
179
218
}
180
219
pytorch_nodes.push_back (n);
181
- prev_non_tensor_outputs = containNonTensorOutputs (n);
182
220
} else {
183
- // If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
184
- // Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
185
221
if (!pytorch_nodes.empty ()) {
186
222
new_seg_blocks.emplace_back (new_seg_blocks.size (), SegmentedBlock::kTorch , pytorch_nodes);
187
223
pytorch_nodes.clear ();
@@ -190,7 +226,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
190
226
}
191
227
}
192
228
193
- // Form the last segmented_block with the left over nodes in tensorrt_nodes or pytorch_nodes correspondingly.
229
+ // Form the last segmented_block with the leftover nodes in tensorrt_nodes or pytorch_nodes correspondingly.
194
230
if (!tensorrt_nodes.empty ()) {
195
231
new_seg_blocks.emplace_back (new_seg_blocks.size (), SegmentedBlock::kTensorRT , tensorrt_nodes);
196
232
} else {
0 commit comments