Skip to content

Commit 46950bb

Browse files
committed
fix: support shape inference for add_, support non-tensor arguments for segmented graphs
Signed-off-by: Bo Wang <[email protected]>
1 parent 8b7919f commit 46950bb

File tree

7 files changed

+103
-41
lines changed

7 files changed

+103
-41
lines changed

Diff for: core/compiler.cpp

+17-12
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,18 @@ void AddEngineToGraph(
111111
g->block()->appendNode(unpack_node);
112112

113113
// If there are multiple output tensors from TensorRT we wrap them in a tuple
114-
// to return
115-
if (unpack_node->outputs().size() > 1) {
114+
// to return, convert to tuple only when we only have 1 segmented graph
115+
if (!engine_id && unpack_node->outputs().size() > 1) {
116116
// Creates prim::TupleConstruct(<output tensors>) using outputs of the
117117
// unpack node
118118
auto return_tuple_node = g->createTuple(unpack_node->outputs());
119119
g->block()->appendNode(return_tuple_node);
120120
// Set the output as the produced tuple
121121
g->registerOutput(return_tuple_node->outputs()[0]);
122122
} else {
123-
// Set the output as the sole output tensor
124-
g->registerOutput(unpack_node->outputs()[0]);
123+
for (int i = 0; i < unpack_node->outputs().size(); ++i) {
124+
g->registerOutput(unpack_node->outputs()[i]);
125+
}
125126
}
126127

127128
LOG_DEBUG(*g << "(AddEngineToGraph)\n");
@@ -159,32 +160,35 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
159160

160161
void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitioning::SegmentedBlock &seg,
161162
std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new_g) {
162-
//old_to_new_g contains: original_graph value => new graph value, mini_graph value -> new graph value, new graph value -> mini_graph value
163+
//old_to_new_g contains: original global graph value => new global graph value,
164+
//mini_to_new_g: mini graph value -> new graph value
165+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
163166
size_t input_idx = 0;
164167
if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
165168
if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
166169
auto self = g->insertInput(0, "self_1");
167170
self->setType(seg.inputs()[0]->type());
168171
}
169-
old_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
172+
mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
170173
}
171174

175+
172176
for (auto &raw_input : seg.raw_inputs()) {
173177
if (old_to_new_g.count(raw_input)) {
174-
old_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
178+
mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
175179
}
176180
}
177181

178182
for (const auto n : seg.nodes()) {
179-
partitioning::cloneNode(n, g, old_to_new_g);
183+
partitioning::cloneNode(n, g, mini_to_new_g);
180184
}
181185

182186
// original graph value => new global graph value
183187
for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
184-
old_to_new_g[seg.raw_outputs()[i]] = old_to_new_g[seg.outputs()[i]];
188+
old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
185189
}
186190

187-
// LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
191+
LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
188192
return;
189193
}
190194

@@ -199,21 +203,21 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
199203
if (method.name().rfind("_", 0)) {
200204
auto new_g = std::make_shared<torch::jit::Graph>();
201205
auto graph_and_parameters = lowering::Lower(mod, method.name());
206+
LOG_INFO(*(method.graph()) << "Original grpah\n");
202207

203208
auto g = graph_and_parameters.first;
204209
auto params = graph_and_parameters.second;
205210
auto named_params = conversion::get_named_params(g->inputs(), params);
206211
auto convert_cfg = std::move(cfg.convert_info);
207212
LOG_INFO(*g << "(CompileGraph)\n");
208213

209-
210214
// segment the graph and convert segmented TensorRT block
211215
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges, convert_cfg.engine_settings.torch_fallback);
212216
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
213217
return mod;
214218
}
215219

216-
int trt_engine_id = 0;
220+
int trt_engine_id = 1;
217221
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
218222
for (auto &seg_block : segmented_blocks) {
219223
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
@@ -225,6 +229,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
225229
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
226230
auto temp_g = std::make_shared<torch::jit::Graph>();
227231
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++);
232+
228233
seg_block.update_graph(temp_g);
229234
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
230235
} else {

Diff for: core/lowering/passes/BUILD

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ cc_library(
2626
"unpack_batch_norm.cpp",
2727
"unpack_log_softmax.cpp",
2828
"op_aliasing.cpp",
29-
"silu_to_sigmoid_multiplication.cpp"
29+
"silu_to_sigmoid_multiplication.cpp",
30+
"remove_inplace_add.cpp"
3031
],
3132
deps = [
3233
"//core/util:prelude",

Diff for: core/lowering/passes/passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
2121
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
2222
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
2323
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
24+
void RemoveInplaceAdd(std::shared_ptr<torch::jit::Graph>& graph);
2425

2526
} // namespace passes
2627
} // namespace lowering

Diff for: core/lowering/passes/remove_inplace_add.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
void RemoveInplaceAdd(std::shared_ptr<torch::jit::Graph>& graph) {
11+
std::string inplace_add_pattern = R"IR(
12+
graph(%self, %other, %1):
13+
%out = aten::add_(%self, %other, %1)
14+
return (%out))IR";
15+
std::string normal_add_pattern = R"IR(
16+
graph(%self, %other, %1):
17+
%out = aten::add(%self, %other, %1)
18+
return (%out))IR";
19+
20+
torch::jit::SubgraphRewriter remove_inplace_add;
21+
remove_inplace_add.RegisterRewritePattern(inplace_add_pattern, normal_add_pattern);
22+
remove_inplace_add.runOnGraph(graph);
23+
24+
LOG_GRAPH("Post remove inplace add: " << *graph);
25+
}
26+
27+
} // namespace passes
28+
} // namespace lowering
29+
} // namespace core
30+
} // namespace trtorch

Diff for: core/partitioning/BUILD

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ cc_library(
1717
],
1818
deps = [
1919
"//core/conversion",
20-
"//core/util:prelude"
20+
"//core/util:prelude",
21+
"//core/lowering"
2122
] + select({
2223
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
2324
"//conditions:default": ["@libtorch//:libtorch"],

Diff for: core/partitioning/partitioning.cpp

+45-24
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#include "core/util/prelude.h"
33
#include "torch/csrc/jit/api/module.h"
44
#include "core/util/prelude.h"
5+
#include "core/lowering/passes/passes.h"
6+
57

68

79
namespace trtorch {
@@ -20,9 +22,9 @@ torch::jit::Value* getOrAddInputForValue(torch::jit::Value* old_value, std::shar
2022
}
2123
auto new_value = graph->block()->addInput();
2224
old_to_new[old_value] = new_value;
25+
new_value->copyMetadata(old_value);
2326
// mapping from new graph input Values to original graph values
2427
old_to_new[new_value] = old_value;
25-
new_value->copyMetadata(old_value);
2628
return new_value;
2729
} else {
2830
return old_to_new[old_value];
@@ -40,7 +42,6 @@ torch::jit::Node* cloneNode(torch::jit::Node* node, std::shared_ptr<torch::jit::
4042
auto no = new_node->outputs()[i];
4143
old_to_new[oo] = no;
4244
}
43-
4445
return new_node;
4546
}
4647

@@ -58,10 +59,13 @@ c10::FunctionSchema getFunctionSchema(std::string method_name, std::shared_ptr<t
5859
return c10::FunctionSchema(method_name, method_name, args, returns);
5960
}
6061

61-
void registerSegmentInOutShape(SegmentedBlock &seg_block, std::unordered_map<torch::jit::Value*, nvinfer1::Dims> &input_shape_map) {
62+
void registerSegmentInOutIValues(SegmentedBlock &seg_block, std::unordered_map<torch::jit::Value*, torch::jit::IValue> &ivalues_maps) {
6263
// create a module to run the graph
6364
auto g = seg_block.g();
6465
auto copy_g = g->copy();
66+
lowering::passes::RemoveInplaceAdd(copy_g);
67+
68+
// create tuple for multiple outputs
6569
if (seg_block.raw_outputs().size() > 1) {
6670
auto new_output_node = copy_g->appendNode(copy_g->createTuple(copy_g->outputs()));
6771
for (int idx = copy_g->outputs().size() - 1; idx >= 0; --idx) {
@@ -84,46 +88,60 @@ void registerSegmentInOutShape(SegmentedBlock &seg_block, std::unordered_map<tor
8488

8589
// set inputs ivalues
8690
for (auto &input : seg_block.raw_inputs()) {
87-
std::vector<int64_t> shape;
88-
nvinfer1::Dims cur_shape = input_shape_map[input];
89-
shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
90-
auto in = at::randint(5, shape, {at::kCUDA});
91-
jit_inputs_ivalues.push_back(in.clone());
91+
if (!ivalues_maps.count(input)) {
92+
std::cerr << "could find graph input ivalues\n";
93+
}
94+
if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
95+
jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor());
96+
} else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) {
97+
jit_inputs_ivalues.push_back(ivalues_maps[input].toInt());
98+
}
9299
}
93100

94-
std::vector<at::Tensor> jit_results;
101+
std::vector<torch::jit::IValue> jit_results;
95102
torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues);
96-
if (jit_results_ivalues.isTensor()) {
97-
jit_results.push_back(jit_results_ivalues.toTensor());
98-
} else {
103+
if (jit_results_ivalues.isTuple()) {
99104
auto results = jit_results_ivalues.toTuple()->elements();
100105
for (auto r : results) {
101-
jit_results.push_back(r.toTensor());
106+
jit_results.push_back(r);
102107
}
108+
} else {
109+
jit_results.push_back(jit_results_ivalues);
103110
}
104111

105112
size_t idx = 0;
106113
for (auto &output : seg_block.raw_outputs()) {
107-
input_shape_map[output] = util::toDims(jit_results[idx++].sizes());
114+
ivalues_maps[output] = jit_results[idx++];
108115
}
109116

117+
// set input shape for each segmented block so we wil use it in conversion process
110118
std::vector<nvinfer1::Dims> input_shape;
111119
for (auto &i : seg_block.raw_inputs()) {
112-
input_shape.push_back(input_shape_map[i]);
120+
if (ivalues_maps[i].isTensor()) {
121+
input_shape.push_back(util::toDims(ivalues_maps[i].toTensor().sizes()));
122+
}
113123
}
114124

115125
seg_block.register_inshape(input_shape);
116126
}
117127

118-
std::vector<nvinfer1::Dims> extractNvinfer1Dims(std::vector<conversion::InputRange>& input_ranges) {
119-
std::vector<nvinfer1::Dims> res;
128+
129+
std::vector<torch::jit::IValue> generateRandomInputs(std::vector<conversion::InputRange>& input_ranges) {
130+
std::vector<torch::jit::IValue> random_inputs;
120131
for (auto &input_range : input_ranges) {
121-
res.push_back(input_range.input_shape);
132+
auto cur_shape = input_range.input_shape;
133+
std::vector<int64_t> shape;
134+
shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
135+
auto in = at::randint(5, shape, {at::kCUDA});
136+
random_inputs.push_back(in.clone());
137+
printf("is tensor: %d\n", random_inputs.back().isTensor());
122138
}
123-
return res;
139+
return random_inputs;
124140
}
125141

142+
126143
void registerSegmentsInputsOutputs(std::vector<SegmentedBlock> &segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
144+
// find the corresponding raw values in original global graph for this segmented block's inputs/outputs
127145
std::set<torch::jit::Value*> input_values;
128146
for (auto &seg_block : segmented_blocks) {
129147
seg_block.registerInputs();
@@ -176,6 +194,7 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
176194

177195
for (const auto n : nodes) {
178196
if (n->kind() == torch::jit::prim::Constant) continue;
197+
179198
std::string node_string(n->kind().toQualString());
180199

181200
if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
@@ -186,19 +205,21 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
186205
}
187206
}
188207
merge_nodes(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
189-
if (!pytorch_nodes.empty()) segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
208+
if (!pytorch_nodes.empty()) {
209+
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
210+
}
190211

191212
registerSegmentsInputsOutputs(segmented_blocks, g);
192213

193-
std::vector<nvinfer1::Dims> graph_inputs_shape = extractNvinfer1Dims(input_ranges);
194-
std::unordered_map<torch::jit::Value*, nvinfer1::Dims> input_shape_map;
214+
std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
195215

216+
std::vector<torch::jit::IValue> random_inputs = generateRandomInputs(input_ranges);
196217
for (size_t i = 0; i < g->inputs().size(); ++i) {
197-
input_shape_map[g->inputs()[i]] = graph_inputs_shape[i];
218+
ivalues_maps[g->inputs()[i]] = random_inputs[i];
198219
}
199220

200221
for (auto &seg_block : segmented_blocks) {
201-
registerSegmentInOutShape(seg_block, input_shape_map);
222+
registerSegmentInOutIValues(seg_block, ivalues_maps);
202223
}
203224

204225
return segmented_blocks;

Diff for: core/partitioning/partitioning.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ namespace trtorch {
1010
namespace core {
1111
namespace partitioning {
1212

13+
torch::jit::Value* getOrAddInputForValue(torch::jit::Value* old_value, std::shared_ptr<torch::jit::Graph> &graph,
14+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new);
15+
1316
torch::jit::Node* cloneNode(torch::jit::Node* node, std::shared_ptr<torch::jit::Graph> &graph,
1417
std::unordered_map<torch::jit::Value*, torch::jit::Value*> &old_to_new);
1518

@@ -49,7 +52,6 @@ struct SegmentedBlock {
4952

5053
void registerOutput(torch::jit::Value* raw_input) {
5154
outputs_.push_back(raw_input);
52-
5355
g_->registerOutput(old_to_new_[raw_input]);
5456
}
5557

@@ -97,15 +99,16 @@ struct SegmentedBlock {
9799
return out_shape_;
98100
}
99101

100-
const std::shared_ptr<torch::jit::Graph>& g() const {
102+
std::shared_ptr<torch::jit::Graph>& g() {
101103
return g_;
102104
}
103105

106+
104107
void update_graph(std::shared_ptr<torch::jit::Graph> new_g) {
105108
g_ = new_g;
106109
}
107110

108-
private:
111+
// private:
109112
SegmentedBlockTarget target_;
110113
std::vector<nvinfer1::Dims> in_shape_;
111114
std::vector<nvinfer1::Dims> out_shape_;

0 commit comments

Comments (0)