 #include "core/lowering/passes/passes.h"
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/api/module.h"
+#include "torch/csrc/jit/ir/constants.h"
 
 namespace trtorch {
 namespace core {
@@ -67,6 +68,7 @@ void registerSegmentInOutIValues(
   // create a module to run the graph
   auto g = seg_block.g();
   auto copy_g = g->copy();
+  // LOG_INFO(*copy_g << "(copy graph)\n");
 
   // create tuple for multiple outputs
   if (seg_block.raw_outputs().size() > 1) {
@@ -163,19 +165,53 @@ void registerSegmentsInputsOutputs(
     input_values.insert(graph_output);
   }
 
-  for (auto& mini_graph_input : input_values) {
-    for (auto& seg_block : segmented_blocks) {
+  // Be careful here: some in-place operations don't return any values.
+  for (auto& seg_block : segmented_blocks) {
+    for (auto& mini_graph_input : input_values) {
       if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
               seg_block.raw_inputs().end() &&
           seg_block.contain_raw_input(mini_graph_input)) {
         seg_block.registerOutput(mini_graph_input);
       }
     }
+    if (seg_block.raw_outputs().empty()) {
+      seg_block.registerOutput(seg_block.raw_inputs()[0]);
+    }
   }
 
   return;
 }
 
+void eraseNonTensorInputsOutputs(
+    SegmentedBlock& seg_block,
+    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
+  if (seg_block.target() == SegmentedBlock::kTorch)
+    return;
+  auto mini_graph = seg_block.g();
+
+  for (int i = seg_block.raw_inputs().size() - 1; i >= 0; --i) {
+    // Erase this input and replace its uses with a prim::Constant if it's not a Tensor (or Tensor list).
+    if (!seg_block.raw_inputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
+        !seg_block.raw_inputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
+      auto new_val = torch::jit::insertConstant(*mini_graph, ivalues_maps[seg_block.raw_inputs()[i]]);
+      seg_block.inputs()[i]->replaceAllUsesWith(new_val);
+      seg_block.eraseInput(i);
+    }
+  }
+
+  for (int i = seg_block.raw_outputs().size() - 1; i >= 0; --i) {
+    if (!seg_block.raw_outputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
+        !seg_block.raw_outputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
+      seg_block.eraseOutput(i);
+    }
+  }
+
+  // Not sure whether to delete this block entirely or just fall back to PyTorch.
+  if (seg_block.raw_outputs().empty()) {
+    seg_block.update_target(SegmentedBlock::kTorch);
+  }
+}
+
 void construct_segments(
     std::vector<torch::jit::Node*>& pytorch_nodes,
     std::vector<torch::jit::Node*>& tensorrt_nodes,
@@ -240,6 +276,7 @@ std::vector<SegmentedBlock> segment_graph(
   // register every segment's input shapes and its running output IValues
   for (auto& seg_block : segmented_blocks) {
     registerSegmentInOutIValues(seg_block, ivalues_maps);
+    eraseNonTensorInputsOutputs(seg_block, ivalues_maps);
   }
 
   return segmented_blocks;
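
Note on the new pass: eraseNonTensorInputsOutputs relies on torch::jit::insertConstant (hence the new include of torch/csrc/jit/ir/constants.h) to bake the already-evaluated IValue of each non-Tensor input into the segment's graph as a prim::Constant before that input is erased. Below is a minimal standalone sketch of that call, assuming libtorch is available and linked; the toy graph and the int value are invented for illustration.

#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <iostream>
#include <memory>

int main() {
  // Toy graph standing in for a segment's mini-graph.
  auto g = std::make_shared<torch::jit::Graph>();

  // A non-Tensor value, e.g. an int that was already resolved while running the segment.
  torch::jit::IValue scalar(static_cast<int64_t>(42));

  // insertConstant adds a prim::Constant node to the graph and returns its output Value*;
  // the pass then redirects uses of the old input to this Value before erasing the input.
  torch::jit::Value* c = torch::jit::insertConstant(*g, scalar);
  g->registerOutput(c);

  std::cout << *g << std::endl; // the printed graph contains a prim::Constant[value=42] node
  return 0;
}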
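A second small note: both erase loops in the new pass walk the inputs/outputs from the back (i = size() - 1 down to 0), so erasing index i never shifts the indices that are still to be visited. A self-contained sketch of that pattern on a plain std::vector; the data is invented and negative values stand in for the non-Tensor entries.

#include <cstdio>
#include <vector>

int main() {
  // Stand-in for a segment's input list; negative values play the role of non-Tensor inputs.
  std::vector<int> inputs = {1, -2, 3, -4, 5};

  // Iterate in reverse so that erasing index i leaves the not-yet-visited indices (< i) intact.
  for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) {
    if (inputs[i] < 0) {
      inputs.erase(inputs.begin() + i);
    }
  }

  for (int v : inputs) {
    std::printf("%d ", v); // prints: 1 3 5
  }
  std::printf("\n");
  return 0;
}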