Skip to content

Commit 54e407e

Browse files
committed
feat: support Int/Bool and other constants' inputs/outputs for TensorRT segments
Signed-off-by: Bo Wang <[email protected]>
1 parent 6147d4f commit 54e407e

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

Diff for: core/conversion/evaluators/aten.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ auto aten_registrations TRTORCH_UNUSED =
436436
if (args.at(n->input(0)).IValue()->isInt()) {
437437
auto a = args.at(n->input(0)).unwrapToInt();
438438
auto b = args.at(n->input(1)).unwrapToInt();
439-
return std::floor(a / b);
439+
return static_cast<int>(std::floor(a / b));
440440
} else if (args.at(n->input(0)).IValue()->isDouble()) {
441441
auto a = args.at(n->input(0)).unwrapToDouble();
442442
auto b = args.at(n->input(1)).unwrapToDouble();

Diff for: core/partitioning/partitioning.cpp

+39-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "core/lowering/passes/passes.h"
44
#include "core/util/prelude.h"
55
#include "torch/csrc/jit/api/module.h"
6+
#include "torch/csrc/jit/ir/constants.h"
67

78
namespace trtorch {
89
namespace core {
@@ -67,6 +68,7 @@ void registerSegmentInOutIValues(
6768
// create a module to run the graph
6869
auto g = seg_block.g();
6970
auto copy_g = g->copy();
71+
// LOG_INFO(*copy_g << "(copy graph)\n");
7072

7173
// create tuple for multiple outputs
7274
if (seg_block.raw_outputs().size() > 1) {
@@ -163,19 +165,53 @@ void registerSegmentsInputsOutputs(
163165
input_values.insert(graph_output);
164166
}
165167

166-
for (auto& mini_graph_input : input_values) {
167-
for (auto& seg_block : segmented_blocks) {
168+
// should be careful here because some in-place operations don't return any values
169+
for (auto& seg_block : segmented_blocks) {
170+
for (auto& mini_graph_input : input_values) {
168171
if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
169172
seg_block.raw_inputs().end() &&
170173
seg_block.contain_raw_input(mini_graph_input)) {
171174
seg_block.registerOutput(mini_graph_input);
172175
}
173176
}
177+
if (seg_block.raw_outputs().empty()) {
178+
seg_block.registerOutput(seg_block.raw_inputs()[0]);
179+
}
174180
}
175181

176182
return;
177183
}
178184

185+
/**
 * Strip non-Tensor inputs and outputs from a TensorRT-targeted segment.
 *
 * Every raw input whose type is neither a Tensor nor a list of Tensors is replaced,
 * inside the segment's graph, by a constant built from the value recorded for it in
 * ivalues_maps; the input is then erased from the segment. Non-Tensor outputs are
 * erased as well. If no outputs remain afterwards the segment is retargeted to kTorch.
 * Segments that already target kTorch are returned untouched.
 *
 * @param seg_block     segment to rewrite in place
 * @param ivalues_maps  values recorded for graph Values; only read by this function
 */
void eraseNonTensorInputsOutputs(
    SegmentedBlock& seg_block,
    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
  if (seg_block.target() == SegmentedBlock::kTorch) {
    return;
  }

  // A value is kept as a real segment input/output only if it is a Tensor or a Tensor list.
  const auto is_tensor_like = [](torch::jit::Value* val) {
    const auto& type = val->type();
    return type->isSubtypeOf(torch::jit::TensorType::get()) || type->isSubtypeOf(c10::ListType::ofTensors());
  };

  // Check up front, before mutating anything: a non-Tensor input with no recorded value
  // cannot be frozen into a constant. Indexing the map with operator[] would silently
  // default-insert an empty IValue and bake that in, so fall back to PyTorch instead.
  for (auto* raw_input : seg_block.raw_inputs()) {
    if (!is_tensor_like(raw_input) && ivalues_maps.find(raw_input) == ivalues_maps.end()) {
      seg_block.update_target(SegmentedBlock::kTorch);
      return;
    }
  }

  auto mini_graph = seg_block.g();

  {
    // insertConstant adds the node at the graph's current insertion point. Pin that point
    // to the front of the graph so each constant is defined before the nodes that consume
    // it; the guard restores the previous insertion point when it goes out of scope.
    torch::jit::WithInsertPoint insert_guard(*mini_graph->nodes().begin());

    // Walk backwards so that erasing index i does not shift the indices still to be visited.
    for (size_t i = seg_block.raw_inputs().size(); i-- > 0;) {
      auto* raw_input = seg_block.raw_inputs()[i];
      if (is_tensor_like(raw_input)) {
        continue;
      }
      auto new_val = torch::jit::insertConstant(*mini_graph, ivalues_maps.at(raw_input));
      seg_block.inputs()[i]->replaceAllUsesWith(new_val);
      seg_block.eraseInput(i);
    }
  }

  for (size_t i = seg_block.raw_outputs().size(); i-- > 0;) {
    if (!is_tensor_like(seg_block.raw_outputs()[i])) {
      seg_block.eraseOutput(i);
    }
  }

  // not sure to delete this block or just fallback to pytorch
  if (seg_block.raw_outputs().empty()) {
    seg_block.update_target(SegmentedBlock::kTorch);
  }
}
214+
179215
void construct_segments(
180216
std::vector<torch::jit::Node*>& pytorch_nodes,
181217
std::vector<torch::jit::Node*>& tensorrt_nodes,
@@ -240,6 +276,7 @@ std::vector<SegmentedBlock> segment_graph(
240276
// register every segment's input shape, and its running output IValues
241277
for (auto& seg_block : segmented_blocks) {
242278
registerSegmentInOutIValues(seg_block, ivalues_maps);
279+
eraseNonTensorInputsOutputs(seg_block, ivalues_maps);
243280
}
244281

245282
return segmented_blocks;

Diff for: core/partitioning/partitioning.h

+14
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,20 @@ struct SegmentedBlock {
6666
return g_->inputs();
6767
}
6868

69+
// Removes the i-th input from the segment's graph and from the recorded raw inputs.
// The graph is edited first: if g_->eraseInput rejects the request (presumably when the
// input still has uses or i is out of range -- confirm against the JIT IR), inputs_ is
// left unmodified, so the two sequences cannot end up with different lengths.
void eraseInput(size_t i) {
  g_->eraseInput(i);
  inputs_.erase(inputs_.begin() + i);
}
73+
6974
// Returns the outputs of the segment's own graph (g_), as opposed to the values
// recorded in outputs_.
c10::ArrayRef<torch::jit::Value*> outputs() {
7075
return g_->outputs();
7176
}
7277

78+
// Removes the i-th output from the segment's graph and from the recorded raw outputs.
// The graph is edited first: if g_->eraseOutput rejects the request, outputs_ is left
// unmodified, so the two sequences cannot end up with different lengths.
void eraseOutput(size_t i) {
  g_->eraseOutput(i);
  outputs_.erase(outputs_.begin() + i);
}
82+
7383
// Read-only access to inputs_, the input Values recorded for this segment
// (distinct from g_->inputs(), which inputs() returns).
const std::vector<torch::jit::Value*>& raw_inputs() const {
7484
return inputs_;
7585
}
@@ -102,6 +112,10 @@ struct SegmentedBlock {
102112
g_ = new_g;
103113
}
104114

115+
// Overwrites the segment's target with new_target; nothing else about the
// segment is modified here.
void update_target(SegmentedBlockTarget new_target) {
116+
target_ = new_target;
117+
}
118+
105119
private:
106120
SegmentedBlockTarget target_;
107121
std::vector<nvinfer1::Dims> in_shape_;

0 commit comments

Comments
 (0)