Automatic Fallback #406
Changes from 47 commits
core/compiler.cpp

@@ -21,37 +21,24 @@
 #include "core/conversion/conversion.h"
 #include "core/lowering/lowering.h"
+#include "core/partitioning/partitioning.h"
 #include "core/runtime/runtime.h"

 namespace trtorch {
 namespace core {

-c10::FunctionSchema GenerateGraphSchema(
-    torch::jit::script::Module mod,
-    std::string method_name,
-    std::shared_ptr<torch::jit::Graph>& g) {
-  std::vector<c10::Argument> args;
-  for (auto in : g->inputs()) {
-    args.push_back(c10::Argument(in->debugName(), in->type()));
-  }
-
-  std::vector<c10::Argument> returns;
-  for (auto out : g->outputs()) {
-    returns.push_back(c10::Argument(out->debugName(), out->type()));
-  }
-
-  return c10::FunctionSchema(method_name, method_name, args, returns);
-}
-
 void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
-    const std::string& serialized_engine) {
-  auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
+    const std::string& serialized_engine,
+    int engine_id = 0) {
+  auto engine_ptr =
+      c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + std::to_string(engine_id), serialized_engine);

Review comment: Does engine_id just need to be unique, or do we use the ids elsewhere? If they just need to be unique, we should use the pointer trick to get something that is likely to be unique, so we don't really need to worry about conflicts.

   // Get required metadata about the engine out
   auto num_io = engine_ptr->num_io;
   auto name = engine_ptr->name;

+  //..

Review comment: This can be removed.

   // Add the engine as an attribute of the module, this will let the engine be
   // serialized and deserialized
   mod.register_attribute(
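The "pointer trick" in the engine_id review comment above presumably means using a heap address as a likely-unique identifier instead of a caller-supplied counter. A minimal sketch of the idea (a hypothetical helper, not part of this PR):

#include <cstdint>
#include <string>

// Hypothetical helper: derive a likely-unique engine name from the address of
// an object that lives as long as the engine (e.g. the TRTEngine itself).
// Live-object addresses are unique within a process, so no id needs to be
// threaded through the call chain.
std::string UniqueEngineName(const std::string& base, const void* obj) {
  return base + "_" + std::to_string(reinterpret_cast<std::uintptr_t>(obj));
}

Note, however, that this PR also uses engine_id == 0 as a flag for the non-fallback path (see the next hunk), so a purely address-based id would need that signal split into a separate argument.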
@@ -108,17 +95,19 @@ void AddEngineToGraph(
   g->block()->appendNode(unpack_node);

   // If there are multiple output tensors from TensorRT we wrap them in a tuple
-  // to return
-  if (unpack_node->outputs().size() > 1) {
+  // to return, convert to tuple only when we only have 1 segmented graph
+  if (!engine_id && unpack_node->outputs().size() > 1) {

Review comment: Does the case where we have multiple TRT engines never have engine_id 0? We should note this if so.
Reply: Yes. I refactored this function by adding a default argument indicating whether we have fallback or not.

     // Creates prim::TupleConstruct(<output tensors>) using outputs of the
     // unpack node
     auto return_tuple_node = g->createTuple(unpack_node->outputs());
     g->block()->appendNode(return_tuple_node);
     // Set the output as the produced tuple
     g->registerOutput(return_tuple_node->outputs()[0]);
   } else {
-    // Set the output as the sole output tensor
-    g->registerOutput(unpack_node->outputs()[0]);
+    // if fallback is enabled, multiple outputs will be registered
+    for (size_t i = 0; i < unpack_node->outputs().size(); ++i) {
+      g->registerOutput(unpack_node->outputs()[i]);
+    }
   }

   LOG_DEBUG(*g << "(AddEngineToGraph)\n");
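For context, the two call sites later in this diff exercise both branches: the single-engine path leaves engine_id at its default of 0 (keeping the tuple-wrapping behavior for multiple outputs), while the fallback path passes an incrementing nonzero id so each segment's outputs stay unpacked for stitching:

// Single-engine path (CompileGraph): engine_id defaults to 0.
AddEngineToGraph(new_mod, new_g, engine);

// Fallback path (CompileGraphWithFallback): ids start at 1, one per TRT segment.
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++);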
@@ -142,6 +131,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

   auto convert_cfg = std::move(cfg.convert_info);
+  auto g = graph_and_parameters.first;

   auto params = graph_and_parameters.second;
   auto named_params = conversion::get_named_params(g->inputs(), params);

@@ -151,7 +141,111 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   return std::move(engine);
 }

+void AddSegmentedBlockToGraph(
+    std::shared_ptr<torch::jit::Graph>& g,
+    partitioning::SegmentedBlock& seg,
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
+  // old_to_new_g contains: original global graph value => new global graph value,
+  // mini_to_new_g: mini graph value -> new graph value
+  std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
+  size_t input_idx = 0;
+  if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
+    if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
+      auto self = g->insertInput(0, "self_1");
+      self->setType(seg.inputs()[0]->type());
+    }
+    mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
+  }
+
+  for (auto& raw_input : seg.raw_inputs()) {
+    if (old_to_new_g.count(raw_input)) {
+      mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
+    }
+  }
+
+  for (const auto n : seg.nodes()) {
+    util::cloneNode(n, g, mini_to_new_g);
+  }
+
+  // original graph value => new global graph value
+  for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
+    old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
+  }
+
+  return;
+}
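util::cloneNode itself is not shown in this diff. Conceptually it should wrap torch::jit::Graph::createClone with the value map, roughly like the sketch below (an assumed implementation for illustration, not the PR's actual helper):

#include <unordered_map>
#include <torch/csrc/jit/ir/ir.h>

// Sketch: clone one node from a segment's mini-graph into the stitched global
// graph. Inputs are remapped through old_to_new; outputs are recorded so that
// later nodes (and graph outputs) can resolve their producers.
torch::jit::Node* cloneNode(
    torch::jit::Node* node,
    std::shared_ptr<torch::jit::Graph>& g,
    std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new) {
  auto env = [&](torch::jit::Value* v) { return old_to_new.at(v); };
  auto* new_node = g->block()->appendNode(g->createClone(node, env));
  for (size_t i = 0; i < node->outputs().size(); ++i) {
    old_to_new[node->outputs()[i]] = new_node->outputs()[i];
  }
  return new_node;
}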

+torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Module& mod, CompileSpec cfg) {
+  // TODO: Should be doing a functional transform but need PR #31978
+  // [jit] More robust mangling
+  // torch::jit::script::Module new_mod = mod.clone();
+  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
+  std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
+  for (const torch::jit::script::Method& method : mod.get_methods()) {
+    // Don't convert hidden methods
+    if (method.name().rfind("_", 0)) {
+      auto new_g = std::make_shared<torch::jit::Graph>();
+      auto graph_and_parameters = lowering::Lower(mod, method.name());
+
+      auto g = graph_and_parameters.first;
+      auto params = graph_and_parameters.second;
+      auto named_params = conversion::get_named_params(g->inputs(), params);
+      auto convert_cfg = std::move(cfg.convert_info);
+      LOG_INFO(*g << "(LoweringGraph)\n");
+
+      // segment the graph and convert segmented TensorRT block
+      auto segmented_blocks = partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
+      if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
+        return mod;
+      }

Review note: bowang007 marked this conversation as resolved.

+
+      int trt_engine_id = 1;
+      std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
+      // add global graph's input to old_to_new_g mapping
+      for (auto input : g->inputs()) {
+        util::getOrAddInputForValue(input, new_g, old_to_new_g);
+      }
+      for (auto& seg_block : segmented_blocks) {
+        LOG_INFO(*g << "(MiniGraphInSegmentedBlock)\n");
+        if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+          std::vector<ir::InputRange> input_ranges;

Review comment: Does dynamic shape work?
Reply: I don't think so. Currently we haven't considered the case where we have dynamic shapes in shape analysis.
Review comment: That's probably higher priority than loops then, since we have unrolling that can be enabled. Also, I think it's pretty achievable in the time we have.
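On the dynamic-shape thread: the loop below builds each segment's ranges from a single static shape (seg_block.in_shape()), so min, opt, and max collapse to one value. Supporting dynamic shape would mean shape analysis producing a range per segment input. Assuming ir::InputRange keeps the min/opt/max constructor it had before moving into core/ir, a dynamic segment config might look roughly like this (placeholder shapes, not the PR's code):

// Sketch (assumption): a segment whose first (batch) dimension is dynamic.
// The three vectors are min/opt/max shapes; the values are placeholders.
std::vector<ir::InputRange> input_ranges;
input_ranges.push_back(ir::InputRange(
    /*min=*/{1, 3, 224, 224},
    /*opt=*/{8, 3, 224, 224},
    /*max=*/{32, 3, 224, 224}));
convert_cfg.input_ranges = input_ranges;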
+          for (auto& shape : seg_block.in_shape()) {
+            input_ranges.push_back(ir::InputRange(shape));
+          }
+          // update the input ranges for each segment
+          convert_cfg.input_ranges = input_ranges;
+          auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
+          auto temp_g = std::make_shared<torch::jit::Graph>();
+          AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++);
+
+          seg_block.update_graph(temp_g);
+          AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+        } else {
+          AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+        }
+      }
+
+      for (auto& output : g->outputs()) {
+        new_g->registerOutput(old_to_new_g[output]);
+      }
+
+      LOG_INFO(*new_g << "(FallbackGraph)\n");
+
+      auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
+      auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
+      new_mod.type()->addMethod(new_method);
+      new_method->setSchema(schema);
+    }
+  }
+
+  return new_mod;
+}
+
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
+  // TODO: not sure how to deal with duplicated code here, so just cut out a branch temporarily
+  if (cfg.partition_info.enabled) {
+    return CompileGraphWithFallback(mod, cfg);
+  }
   // TODO: Should be doing a functional transform but need PR #31978
   // [jit] More robust mangling
   // torch::jit::script::Module new_mod = mod.clone();

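A minimal sketch of how a caller reaches the new branch, assuming a core::CompileSpec cfg built the usual way (the PartitionInfo fields beyond enabled come from core/partitioning and are not shown in this diff):

// Sketch (assumption): enabling partitioning routes compilation through
// CompileGraphWithFallback; otherwise the original single-engine path runs.
cfg.partition_info.enabled = true;
auto hybrid_mod = core::CompileGraph(mod, cfg);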
@@ -164,7 +258,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
     auto new_g = std::make_shared<torch::jit::Graph>();
     AddEngineToGraph(new_mod, new_g, engine);
     auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
-    auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
+    auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
     new_mod.type()->addMethod(new_method);
     new_method->setSchema(schema);
   }

core/ir/BUILD

@@ -0,0 +1,35 @@
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+    name = "use_pre_cxx11_abi",
+    values = {
+        "define": "abi=pre_cxx11_abi",
+    }
+)
+
+cc_library(
+    name = "ir",
+    hdrs = [
+        "ir.h"
+    ],
+    srcs = [
+        "InputRange.cpp",
+    ],
+    deps = [
+        "@tensorrt//:nvinfer",
+        "//core/util:prelude",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+)
+
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
+pkg_tar(
+    name = "include",
+    package_dir = "core/ir/",
+    srcs = [
+        "ir.h",
+    ],
+)
Review comment: We should think about ways to get more descriptive names in the future.
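For reference on the config_setting above: in standard Bazel usage, the setting matches a --define flag on the command line, so `bazel build //core/ir:ir --define abi=pre_cxx11_abi` would select the `@libtorch_pre_cxx11_abi//:libtorch` dependency, while a plain `bazel build //core/ir:ir` falls through to the `//conditions:default` branch and links against `@libtorch//:libtorch`.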