Skip to content

Commit 114969b

Browse files
authored
Merge pull request #447 from NVIDIA/bowa_primif
feat: support prim::If in automatic fallback
2 parents 7e54e17 + 0b49965 commit 114969b

13 files changed

+313
-102
lines changed

Diff for: core/compiler.cpp

+135-42
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "torch/csrc/jit/frontend/function_schema_parser.h"
1313
#include "torch/csrc/jit/ir/ir.h"
14+
#include "torch/csrc/jit/ir/ir_views.h"
1415
#include "torch/csrc/jit/passes/graph_fuser.h"
1516
#include "torch/csrc/jit/passes/loop_unrolling.h"
1617
#include "torch/csrc/jit/passes/lower_graph.h"
@@ -173,10 +174,131 @@ void AddSegmentedBlockToGraph(
173174
for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
174175
old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
175176
}
177+
size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
178+
for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
179+
if (!old_to_new_g.count(seg.raw_inputs()[i])) {
180+
old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
181+
}
182+
}
176183

177184
return;
178185
}
179186

187+
typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
188+
GraphAndMapping;
189+
190+
void AddIfBlockToGraph(
191+
std::shared_ptr<torch::jit::Graph>& new_g,
192+
torch::jit::Node* if_node,
193+
const std::vector<GraphAndMapping>& graph_and_mappings,
194+
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
195+
torch::jit::IfView if_view(if_node);
196+
197+
// create a new if node in new_g and add corresponding inputs
198+
auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
199+
new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
200+
201+
// iterate over all blocks and add them to new created prim::If
202+
for (auto graph_and_mapping : graph_and_mappings) {
203+
auto new_if_block = new_if->addBlock();
204+
auto cur_block_graph = graph_and_mapping.first;
205+
auto cur_block_mapping = graph_and_mapping.second;
206+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
207+
for (auto& i : cur_block_mapping) {
208+
// for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
209+
// it's mini graph's input
210+
if (old_to_new_g.count(i.first)) {
211+
block_graph_to_new_g[i.second] = old_to_new_g[i.first];
212+
}
213+
}
214+
215+
auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
216+
new_if_block->cloneFrom(cur_block_graph->block(), env);
217+
if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
218+
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
219+
auto self = new_g->insertInput(0, "self_1");
220+
self->setType(cur_block_graph->inputs()[0]->type());
221+
}
222+
block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
223+
}
224+
for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
225+
new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
226+
new_if_block->eraseInput(i);
227+
}
228+
}
229+
for (auto ov : if_view.outputs()) {
230+
auto no = new_if->addOutput();
231+
old_to_new_g[ov] = no;
232+
no->copyMetadata(ov);
233+
}
234+
return;
235+
}
236+
237+
// Recursively converts `block` into a fallback graph: TensorRT-compatible segments
// become embedded TRT engines, unsupported segments are kept as Torch ops, and
// nested prim::If nodes are rebuilt by recursing into each of their branches.
//
// \param new_mod            module the generated TRT engines are registered on
// \param block              block of the lowered graph to convert
// \param input_ivalues_map  example ivalues for the block's inputs, consumed by the
//                           partitioner's shape analysis (taken by value because
//                           partitioning::Partition requires a mutable lvalue)
// \param cfg                compile settings; const-ref avoids copying the whole
//                           spec on every recursive call
// \param named_params       graph parameters used for engine conversion
// \return the new graph plus the old-value => new-graph-value mapping
GraphAndMapping ConstructFallbackGraph(
    torch::jit::script::Module& new_mod,
    torch::jit::Block* block,
    std::unordered_map<torch::jit::Value*, torch::jit::IValue> input_ivalues_map,
    const CompileSpec& cfg,
    conversion::GraphParams named_params) {
  auto convert_cfg = cfg.convert_info; // local copy: its inputs are overwritten per segment
  auto partition_info = cfg.partition_info;

  auto new_g = std::make_shared<torch::jit::Graph>();

  auto segmented_blocks = partitioning::Partition(block, input_ivalues_map, partition_info);

  // the mapping from lowering graph => fallback global graph
  std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
  for (auto input : block->inputs()) {
    util::getOrAddInputForValue(input, new_g, old_to_new_g);
  }

  for (auto& seg_block : segmented_blocks) {
    LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
    // the segment's address serves as a unique id for the generated engine
    std::ostringstream trt_engine_id;
    trt_engine_id << static_cast<const void*>(&seg_block);

    if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
      std::vector<ir::Input> inputs;
      for (const auto& shape : seg_block.in_shape()) {
        inputs.push_back(ir::Input(shape));
      }
      // update the input ranges for each segment
      convert_cfg.inputs = inputs;
      auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
      auto temp_g = std::make_shared<torch::jit::Graph>();
      auto device_spec = convert_cfg.engine_settings.device;
      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

      seg_block.update_graph(temp_g);
      AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
    } else {
      if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
        auto if_node = seg_block.raw_nodes()[0];

        // convert the blocks in prim::If and collect the converted graphs with mappings
        std::vector<GraphAndMapping> graph_and_mappings;
        for (auto cur_block : if_node->blocks()) {
          graph_and_mappings.push_back(
              ConstructFallbackGraph(new_mod, cur_block, input_ivalues_map, cfg, named_params));
        }
        AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);

      } else {
        AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
      }
    }
  }

  // only register outputs that were actually materialized in the fallback graph
  for (auto& output : block->outputs()) {
    if (old_to_new_g.count(output)) {
      new_g->registerOutput(old_to_new_g[output]);
    }
  }
  return {new_g, old_to_new_g};
}
301+
180302
torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
181303
// TODO: Should be doing a functional transform but need PR #31978
182304
// [jit] More robust mangling
@@ -192,53 +314,24 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
192314
auto g = graph_and_parameters.first;
193315
auto params = graph_and_parameters.second;
194316
auto named_params = conversion::get_named_params(g->inputs(), params);
195-
auto convert_cfg = std::move(cfg.convert_info);
196-
LOG_INFO(*g << "(LoweringGraph)\n");
317+
LOG_INFO("(LoweredGraph)\n" << *g);
197318

198-
// segment the graph and convert segmented TensorRT block
199-
auto segmented_blocks = partitioning::Partition(g, convert_cfg.inputs, cfg.partition_info);
200-
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
319+
std::unordered_map<torch::jit::Value*, ir::Input> inputs;
320+
for (size_t i = 0; i < g->inputs().size(); ++i) {
321+
inputs.insert({g->inputs()[i], cfg.convert_info.inputs[i]});
322+
}
323+
auto input_ivalues_map = partitioning::generateRandomInputs(inputs);
324+
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
325+
new_g = graph_and_mapping.first;
326+
LOG_INFO("(FallbackGraph)\n" << *new_g);
327+
328+
// if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
329+
// module
330+
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
201331
LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
202332
return mod;
203333
}
204334

205-
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
206-
// add global graph's input to old_to_new_g mapping
207-
for (auto input : g->inputs()) {
208-
util::getOrAddInputForValue(input, new_g, old_to_new_g);
209-
}
210-
for (auto& seg_block : segmented_blocks) {
211-
std::string cur_block_target =
212-
seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
213-
LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
214-
std::ostringstream trt_engine_id;
215-
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
216-
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
217-
std::vector<ir::Input> inputs;
218-
for (auto& shape : seg_block.in_shape()) {
219-
inputs.push_back(ir::Input(shape));
220-
}
221-
// update the input ranges for each segments
222-
convert_cfg.inputs = inputs;
223-
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
224-
auto temp_g = std::make_shared<torch::jit::Graph>();
225-
auto device_spec = convert_cfg.engine_settings.device;
226-
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
227-
AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
228-
229-
seg_block.update_graph(temp_g);
230-
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
231-
} else {
232-
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
233-
}
234-
}
235-
236-
for (auto& output : g->outputs()) {
237-
new_g->registerOutput(old_to_new_g[output]);
238-
}
239-
240-
LOG_INFO(*new_g << "(FallbackGraph)\n");
241-
242335
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
243336
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
244337
new_mod.type()->addMethod(new_method);

Diff for: core/partitioning/SegmentedBlock.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ namespace trtorch {
44
namespace core {
55
namespace partitioning {
66

7-
SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes)
7+
SegmentedBlock::SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes)
88
: target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
99
for (auto& node : nodes) {
1010
nodes_.push_back(node);

Diff for: core/partitioning/SegmentedBlock.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct SegmentedBlock {
2020

2121
SegmentedBlock() = default;
2222
SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
23-
SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes);
23+
SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
2424
SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
2525

2626
torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);

Diff for: core/partitioning/partitioning.cpp

+34-15
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <queue>
44
#include "core/conversion/conversion.h"
55
#include "core/partitioning/shape_analysis.h"
6+
#include "torch/csrc/jit/passes/constant_pooling.h"
67
#include "torch/csrc/jit/passes/dead_code_elimination.h"
78

89
namespace trtorch {
@@ -85,8 +86,14 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
8586
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the
8687
// dependency nodes at the beginning of the current segmented_block and return this merged segmented_block
8788
if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
88-
dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
89-
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
89+
// if current node is prim::If, just ensure that we have all required input in kTorch
90+
if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
91+
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
92+
new_seg_blocks.push_back(seg_block);
93+
} else {
94+
dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
95+
new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
96+
}
9097
} else {
9198
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
9299
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
@@ -127,7 +134,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
127134
return std::move(new_seg_blocks);
128135
}
129136

130-
void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
137+
void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shared_ptr<torch::jit::Graph> g
131138
// create a list so we can insert SegmentedBlock without losing the iterators
132139
std::list<SegmentedBlock> segmented_blocks_list(segmented_blocks.begin(), segmented_blocks.end());
133140
std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> idx_to_iter;
@@ -169,8 +176,10 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
169176
if (!updated_segments.count(first_torch_id)) {
170177
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
171178
// TRTorch doesn't support non-tensor inputs for a module.
172-
auto new_torch_block = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]).front();
173-
*idx_to_iter[first_torch_id] = new_torch_block;
179+
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
180+
segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
181+
segmented_blocks.insert(
182+
segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
174183
updated_segments.insert(first_torch_id);
175184
}
176185
}
@@ -191,7 +200,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
191200
return;
192201
}
193202

194-
void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
203+
void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) {
195204
// find the corresponding raw values in original global graph for this segmented block's inputs/outputs
196205
std::set<torch::jit::Value*> input_values;
197206
for (auto& seg_block : segmented_blocks) {
@@ -200,7 +209,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
200209
}
201210
}
202211

203-
for (auto& graph_output : g->outputs()) {
212+
for (auto& graph_output : block->outputs()) {
204213
input_values.insert(graph_output);
205214
}
206215

@@ -249,12 +258,12 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, std::shared_ptr
249258
return;
250259
}
251260

252-
std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info) {
261+
std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
253262
auto min_block_size = partition_info.min_block_size;
254263
std::unordered_set<std::string> forced_fallback_operators(
255264
partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
256265

257-
auto nodes = g->block()->nodes();
266+
auto nodes = block->nodes();
258267
std::vector<SegmentedBlock> segmented_blocks;
259268

260269
// segment the nodes
@@ -278,6 +287,16 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
278287
pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
279288
}
280289
tensorrt_nodes.clear();
290+
// if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
291+
// we shouldn't inject node for this block in dependency analysis process
292+
if (n->kind() == torch::jit::prim::If) {
293+
if (!pytorch_nodes.empty()) {
294+
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
295+
pytorch_nodes.clear();
296+
}
297+
segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
298+
continue;
299+
}
281300
pytorch_nodes.push_back(n);
282301
}
283302
}
@@ -295,21 +314,21 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
295314
}
296315

297316
std::vector<SegmentedBlock> Partition(
298-
std::shared_ptr<torch::jit::Graph> g,
299-
std::vector<ir::Input>& inputs,
317+
torch::jit::Block* block,
318+
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
300319
const PartitionInfo& partition_info) {
301320
LOG_DEBUG(partition_info);
302321
// segment lowering global graph into blocks
303-
std::vector<SegmentedBlock> segmented_blocks = segment_graph(g, partition_info);
322+
std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);
304323

305324
// resolve nonTensor inputs/outputs
306-
resolveNonTensorInputs(segmented_blocks, g);
325+
resolveNonTensorInputs(segmented_blocks);
307326

308327
// register input/output torch::jit::Value for segmented graphs
309-
registerSegmentsOutputs(segmented_blocks, g);
328+
registerSegmentsOutputs(segmented_blocks, block);
310329

311330
// run shape analysis on each segmented block
312-
runShapeAnalysis(segmented_blocks, inputs, g);
331+
runShapeAnalysis(segmented_blocks, input_ivalues_map);
313332

314333
return segmented_blocks;
315334
}

Diff for: core/partitioning/partitioning.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "core/ir/ir.h"
66
#include "core/partitioning/PartitionInfo.h"
77
#include "core/partitioning/SegmentedBlock.h"
8+
#include "core/partitioning/shape_analysis.h"
89
#include "core/util/prelude.h"
910
#include "torch/csrc/jit/ir/ir.h"
1011

@@ -14,13 +15,13 @@ namespace partitioning {
1415

1516
typedef std::vector<SegmentedBlock> PartitionedGraph;
1617

17-
PartitionedGraph segment_graph(std::shared_ptr<torch::jit::Graph> g, const PartitionInfo& partition_info);
18+
PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info);
1819

1920
std::vector<SegmentedBlock> Partition(
20-
std::shared_ptr<torch::jit::Graph> g,
21-
std::vector<ir::Input>& inputs,
21+
torch::jit::Block* block,
22+
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& input_ivalues_map,
2223
const PartitionInfo& partition_info);
2324

2425
} // namespace partitioning
2526
} // namespace core
26-
} // namespace trtorch
27+
} // namespace trtorch

0 commit comments

Comments
 (0)