Centralizing Partitioning State #1263

Merged (13 commits) on Sep 22, 2022
220 changes: 46 additions & 174 deletions core/compiler.cpp
@@ -11,7 +11,6 @@

#include "torch/csrc/jit/frontend/function_schema_parser.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/ir_views.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/loop_unrolling.h"
#include "torch/csrc/jit/passes/lower_graph.h"
@@ -128,179 +127,54 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
return conversion::VerifyConverterSupportForBlock(g->block());
}

void AddSegmentedBlockToGraph(
std::shared_ptr<torch::jit::Graph>& g,
partitioning::SegmentedBlock& seg,
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
// old_to_new_g contains: original global graph value => new global graph value,
// mini_to_new_g: mini graph value -> new graph value
std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
size_t input_idx = 0;
if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
auto self = g->insertInput(0, "self_1");
self->setType(seg.inputs()[0]->type());
}
mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
}

for (auto& raw_input : seg.raw_inputs()) {
if (old_to_new_g.count(raw_input)) {
mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
}
}

for (const auto n : seg.nodes()) {
util::cloneNode(n, g, mini_to_new_g);
}

// original graph value => new global graph value
for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
}
size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
if (!old_to_new_g.count(seg.raw_inputs()[i])) {
old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
}
}

return;
}

typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
GraphAndMapping;

void AddIfBlockToGraph(
std::shared_ptr<torch::jit::Graph>& new_g,
torch::jit::Node* if_node,
const std::vector<GraphAndMapping>& graph_and_mappings,
std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
torch::jit::IfView if_view(if_node);

// create a new if node in new_g and add corresponding inputs
auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));

// iterate over all blocks and add them to new created prim::If
for (auto graph_and_mapping : graph_and_mappings) {
auto new_if_block = new_if->addBlock();
auto cur_block_graph = graph_and_mapping.first;
auto cur_block_mapping = graph_and_mapping.second;
std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
for (auto& i : cur_block_mapping) {
// for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
// it's mini graph's input
if (old_to_new_g.count(i.first)) {
block_graph_to_new_g[i.second] = old_to_new_g[i.first];
}
}

auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
new_if_block->cloneFrom(cur_block_graph->block(), env);
if (cur_block_graph->inputs().size() &&
cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
auto self = new_g->insertInput(0, "self_1");
self->setType(cur_block_graph->inputs()[0]->type());
}
block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
}
for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
new_if_block->eraseInput(i);
}
}
for (auto ov : if_view.outputs()) {
auto no = new_if->addOutput();
old_to_new_g[ov] = no;
no->copyMetadata(ov);
}
return;
}

GraphAndMapping ConstructFallbackGraph(
partitioning::GraphAndMapping BuildHybridGraph(
torch::jit::script::Module& new_mod,
torch::jit::Block* block,
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
CompileSpec cfg,
ir::StaticParams static_params,
std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
auto convert_cfg = cfg.convert_info;
auto partition_info = cfg.partition_info;

auto new_g = std::make_shared<torch::jit::Graph>();

auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);

// the mapping from lowering graph => fallback global graph
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
for (auto input : block->inputs()) {
util::getOrAddInputForValue(input, new_g, old_to_new_g);
}

for (auto& seg_block : segmented_blocks) {
LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
std::ostringstream trt_engine_id;
trt_engine_id << reinterpret_cast<const int*>(&seg_block);

if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
auto shapes = seg_block.in_shapes();
auto types = seg_block.in_types();
std::vector<ir::Input> inputs;
for (size_t i = 0; i < shapes.size(); i++) {
auto in = ir::Input(shapes[i]);
in.dtype = util::ScalarTypeToTRTDataType(types[i]);
inputs.push_back(in);
}
// update the input ranges for each segments
convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

// TODO mapping Inputs Ivalue to flatten one here
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_cfg.engine_settings.device;
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

seg_block.update_graph(temp_g);
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
} else {
if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
auto if_node = seg_block.raw_nodes()[0];

// convert the 2 blocks in prim::if and get the converted graph with mappings
std::vector<GraphAndMapping> graph_and_mappings;
for (auto cur_block : if_node->blocks()) {
graph_and_mappings.push_back(
ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
ir::CollectionTypeMap first_use_types) {
auto convert_info = cfg.convert_info;
auto partitioning_info = cfg.partitioning_info;

auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
auto collection_input_ivalues_map =
partitioning::GenerateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);

partitioning::Partition(&partitioning_ctx, collection_input_ivalues_map);

for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;

for (auto& seg_block : segmented_blocks) {
LOG_INFO("Block segment:" << seg_block);
std::ostringstream trt_engine_id;
trt_engine_id << reinterpret_cast<const int*>(&seg_block);

if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
auto shapes = seg_block.in_shapes();
auto types = seg_block.in_types();
std::vector<ir::Input> inputs;
for (size_t i = 0; i < shapes.size(); i++) {
auto in = ir::Input(shapes[i]);
in.dtype = util::ScalarTypeToTRTDataType(types[i]);
inputs.push_back(in);
}
AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
// update the input ranges for each segments
convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

} else {
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
}
}
}
// TODO mapping Inputs Ivalue to flatten one here
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_info.engine_settings.device;
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

if (block->outputs().size() > 1) {
std::vector<torch::jit::Value*> fallback_graph_vector;
for (auto& output : block->outputs()) {
if (old_to_new_g.count(output)) {
fallback_graph_vector.push_back(old_to_new_g[output]);
seg_block.update_graph(temp_g);
}
}
torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
new_g->block()->appendNode(return_tuple_node);
// Set the output as the produced tuple
new_g->registerOutput(return_tuple_node->outputs()[0]);
} else {
if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
}
}
return {new_g, old_to_new_g};

return partitioning::Stitch(&partitioning_ctx, block);
}

void MapInputsAndDetermineDTypes(
@@ -310,6 +184,8 @@ void MapInputsAndDetermineDTypes(
ir::CollectionTypeMap& first_use_type_map) {
cfg.convert_info.collection_input_spec_map =
std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
cfg.partitioning_info.collection_input_spec_map =
ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);

auto collection_inputs = ir::get_collection_inputs(g, static_params);
LOG_DEBUG(
@@ -339,7 +215,7 @@ void MapInputsAndDetermineDTypes(
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec[i].dtype = nvinfer1::DataType::kFLOAT;
} else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
} else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
if (!est_type_opt[i]) {
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
std::stringstream ss;
@@ -424,22 +300,18 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto outputIsCollection = conversion::OutputIsCollection(g->block());
if (cfg.partition_info.enabled &&
if (cfg.partitioning_info.enabled &&
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
!outputIsCollection) {
LOG_INFO("Skipping partitioning since model is fully supported");
}

if (cfg.partition_info.enabled &&
if (cfg.partitioning_info.enabled &&
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
outputIsCollection)) {
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
auto collection_input_ivalues_map =
partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
auto graph_and_mapping = ConstructFallbackGraph(
new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
new_g = graph_and_mapping.first;
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
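To summarize the restructuring above: the diff replaces ConstructFallbackGraph, AddSegmentedBlockToGraph, and AddIfBlockToGraph with a single BuildHybridGraph that keeps all segmentation state in partitioning::PartitioningCtx and delegates graph reconstruction to partitioning::Stitch. Below is a condensed sketch of the new call sequence, using only names that appear in this diff; the helper name BuildHybridGraphSketch and the abbreviated engine-insertion step are illustrative, not code from the PR.

// Condensed sketch of the new flow, not a drop-in implementation. Assumes the
// torch_tensorrt core headers (core/compiler.h, core/partitioning/...) from this tree.
partitioning::GraphAndMapping BuildHybridGraphSketch(
    torch::jit::script::Module& new_mod,
    torch::jit::Block* block,
    CompileSpec cfg,
    ir::StaticParams static_params,
    ir::CollectionTypeMap first_use_types) {
  // All partitioning state is created once and owned by the context.
  auto partitioning_ctx = partitioning::PartitioningCtx(block, cfg.partitioning_info);

  // Example inputs come from the collection input specs now mirrored into partitioning_info.
  auto example_inputs = partitioning::GenerateRandomInputs(
      cfg.partitioning_info.collection_input_spec_map, first_use_types);

  // Segmentation writes its results into the context instead of returning loose blocks.
  partitioning::Partition(&partitioning_ctx, example_inputs);

  // TensorRT-targeted segments are converted to engines and their graphs swapped in place.
  for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
    for (auto& seg_block : partitioned_block.second) {
      if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
        auto engine = conversion::ConvertBlockToEngine(
            seg_block.block(), cfg.convert_info, static_params);
        // ... AddEngineToGraph(new_mod, temp_g, engine, ...) and seg_block.update_graph(temp_g),
        //     exactly as in the diff above ...
      }
    }
  }

  // Re-stitching the segments into a single TorchScript graph is owned by the partitioning module.
  return partitioning::Stitch(&partitioning_ctx, block);
}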
2 changes: 1 addition & 1 deletion core/compiler.h
@@ -19,7 +19,7 @@ struct CompileSpec {
ir::GraphInputs graph_inputs;
conversion::ConversionInfo convert_info;
lowering::LowerInfo lower_info;
partitioning::PartitionInfo partition_info;
partitioning::PartitioningInfo partitioning_info;
};

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -41,6 +41,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
passes::MarkNodesForFallback(g, true);
}
passes::UnpackHardSwish(g);
passes::UnpackHardSigmoid(g);
passes::EliminateExceptionOrPassPattern(g);
passes::ReduceToOperation(g);
passes::ReduceGelu(g);
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -30,6 +30,7 @@ cc_library(
"silu_to_sigmoid_multiplication.cpp",
"unpack_addmm.cpp",
"unpack_batch_norm.cpp",
"unpack_hardsigmoid.cpp",
"unpack_hardswish.cpp",
"unpack_log_softmax.cpp",
"unpack_std.cpp",
1 change: 1 addition & 0 deletions core/lowering/passes/CMakeLists.txt
@@ -17,6 +17,7 @@ target_sources(${lib_name}
"${CMAKE_CURRENT_SOURCE_DIR}/silu_to_sigmoid_multiplication.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_addmm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_batch_norm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -38,6 +38,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);

} // namespace passes
} // namespace lowering
43 changes: 43 additions & 0 deletions core/lowering/passes/unpack_hardsigmoid.cpp
@@ -0,0 +1,43 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph) {
std::string hardsigmoid_pattern = R"IR(
graph(%input):
%result = aten::hardsigmoid(%input)
return (%result))IR";

std::string hardsigmoid_pattern_inplace = R"IR(
graph(%input):
%result = aten::hardsigmoid_(%input)
return (%result))IR";

std::string new_pattern = R"IR(
graph(%x.1):
%22 : float = prim::Constant[value=0.5]()
%3 : int = prim::Constant[value=6]()
%5 : int = prim::Constant[value=1]()
%10 : int = prim::Constant[value=0]()
%4 : Tensor = aten::div(%x.1, %3)
%9 : Tensor = aten::add(%4, %22, %5)
%21 : Tensor = aten::clamp(%9, %10, %5)
return (%21))IR";

torch::jit::SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(hardsigmoid_pattern, new_pattern);
rewriter.RegisterRewritePattern(hardsigmoid_pattern_inplace, new_pattern);
rewriter.runOnGraph(graph);

LOG_GRAPH("Post unpack hardsigmoid: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
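For context on the rewrite above (an illustration, not part of the PR): aten::hardsigmoid(x) is defined as clamp(x/6 + 1/2, 0, 1), which is exactly what the replacement div -> add -> clamp subgraph computes. A small, self-contained C++ check of that equivalence follows; the function names are made up for the example.

#include <algorithm>
#include <cassert>
#include <cmath>

// Reference definition used by PyTorch: hardsigmoid(x) = max(0, min(1, x/6 + 1/2)).
float hardsigmoid_ref(float x) {
  return std::max(0.0f, std::min(1.0f, x / 6.0f + 0.5f));
}

// Mirrors the unpacked IR pattern above: aten::div -> aten::add -> aten::clamp.
float hardsigmoid_unpacked(float x) {
  float div = x / 6.0f;                // aten::div(%x.1, %3) with %3 = 6
  float add = div + 0.5f;              // aten::add(%4, %22, %5) with %22 = 0.5, alpha %5 = 1
  return std::clamp(add, 0.0f, 1.0f);  // aten::clamp(%9, %10, %5) with bounds 0 and 1
}

int main() {
  for (float x : {-4.0f, -3.0f, -0.5f, 0.0f, 1.5f, 3.0f, 4.0f}) {
    assert(std::abs(hardsigmoid_ref(x) - hardsigmoid_unpacked(x)) < 1e-6f);
  }
  return 0;
}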
12 changes: 4 additions & 8 deletions core/partitioning/BUILD
@@ -13,22 +13,21 @@ config_setting(
cc_library(
name = "partitioning",
srcs = [
"PartitionInfo.cpp",
"SegmentedBlock.cpp",
"partitioning.cpp",
"shape_analysis.cpp",
"stitching.cpp",
],
hdrs = [
"PartitionInfo.h",
"SegmentedBlock.h",
"partitioning.h",
"shape_analysis.h",
],
deps = [
"//core/util:prelude",
"//core/ir",
"//core/conversion",
"//core/lowering",
"//core/partitioning/partitioningctx",
"//core/partitioning/partitioninginfo",
"//core/partitioning/segmentedblock",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
@@ -39,10 +38,7 @@ cc_library(
pkg_tar(
name = "include",
srcs = [
"PartitionInfo.h",
"SegmentedBlock.h",
"partitioning.h",
"shape_analysis.h",
],
package_dir = "core/partitioning/",
)