@@ -115,7 +115,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
       pytorch_nodes.push_back(n);
       prev_non_tensor_outputs = containNonTensorOutputs(n);
     } else {
-      // If pytorch_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
+      // If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
       // Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
       if (!pytorch_nodes.empty()) {
         new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
@@ -132,6 +132,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
       new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
     }
   }
+
   return std::move(new_seg_blocks);
 }
 
@@ -159,6 +160,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
       }
     }
 
+    // For each non-tensor value in the usage_counts map, keep updating the produce_id to the earliest segmented block that has/produces it.
     for (auto& use : usage_counts) {
       // Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value
       if (segmented_blocks[i].contain_raw_value(use.first)) {
@@ -167,6 +169,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
     }
   }
 
+
   std::unordered_set<int> updated_segments;
   for (auto& use : usage_counts) {
     auto use_info = use.second;
@@ -178,9 +181,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
       // Segmented Blocks with non-tensor inputs will have to be re-segmented as
       // TRTorch doesn't support non-tensor inputs for a module.
       auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
-      segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
-      segmented_blocks.insert(
-          segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
+      auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
+      segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
       updated_segments.insert(first_torch_id);
     }
   }
@@ -314,6 +316,7 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
       segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
       continue;
     } else if (n->kind() == torch::jit::prim::Loop) {
+
       if (!pytorch_nodes.empty()) {
         segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
         pytorch_nodes.clear();
@@ -347,19 +350,15 @@ std::vector<SegmentedBlock> Partition(
     const PartitionInfo& partition_info) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
-  LOG_DEBUG("Partitioning graph into PyTorch and TensorRT segmented blocks");
   std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);
 
   // resolve nonTensor inputs/outputs
-  LOG_DEBUG("Resolving non-tensor type inputs/outputs (eg: int/float types)");
   resolveNonTensorInputs(segmented_blocks);
 
   // register input/output torch::jit::Value for segmented graphs
-  LOG_DEBUG("Registering input/outputs for segmented blocks");
   registerSegmentsOutputs(segmented_blocks, block);
 
   // run shape analysis on each segmented block
-  LOG_DEBUG("Running shape analysis for all the segmented blocks");
   runShapeAnalysis(segmented_blocks, input_ivalues_map);
 
   return segmented_blocks;
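
The erase/insert change in resolveNonTensorInputs is the substantive part of this diff: index-based std::vector::erase/insert invalidates iterators and shifts indices behind the edit point, whereas std::list operations leave iterators to all other elements valid, so saved iterators (looked up here through idx_to_iter, presumably built earlier in the function alongside segmented_blocks_list) remain usable after new blocks are injected. A minimal standalone sketch of that pattern, using plain ints in place of SegmentedBlocks (all names below are illustrative, not TRTorch types):

// Sketch: iterator-stable erase/insert with std::list. std::list::erase
// returns the iterator following the erased element; insert places the new
// range directly before that position, and iterators to every other element
// remain valid throughout.
#include <iostream>
#include <list>
#include <unordered_map>
#include <vector>

int main() {
  std::list<int> blocks = {10, 20, 30};

  // Index -> iterator lookup, mirroring the role of idx_to_iter above.
  std::unordered_map<int, std::list<int>::iterator> idx_to_iter;
  int idx = 0;
  for (auto it = blocks.begin(); it != blocks.end(); ++it) {
    idx_to_iter[idx++] = it;
  }

  // Replace the block at index 1 with a freshly segmented sequence.
  std::vector<int> to_inject = {21, 22, 23};
  auto next_iter = blocks.erase(idx_to_iter[1]);
  blocks.insert(next_iter, to_inject.begin(), to_inject.end());

  // idx_to_iter[0] and idx_to_iter[2] still point at 10 and 30.
  for (int b : blocks) {
    std::cout << b << ' '; // prints: 10 21 22 23 30
  }
  std::cout << '\n';
  return 0;
}

With the old std::vector version, the second erase/insert pass would have had to recompute positions by index after every injection; the list version gets the correct resume point for free from erase's return value.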