Skip to content

Centralizing Partitioning State #1263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 36 additions & 24 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,19 +219,16 @@ void AddIfBlockToGraph(
return;
}

GraphAndMapping ConstructFallbackGraph(
GraphAndMapping ConstructFallbackGraph_(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have a more distinguishing name for this maybe?

torch::jit::script::Module& new_mod,
torch::jit::Block* block,
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
CompileSpec cfg,
partitioning::PartitioningCtx* partitioning_ctx,
conversion::ConversionInfo convert_info,
ir::StaticParams static_params,
std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
auto convert_cfg = cfg.convert_info;
auto partition_info = cfg.partition_info;

std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map) {
auto new_g = std::make_shared<torch::jit::Graph>();

auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);
auto segmented_blocks = partitioning::partition(partitioning_ctx, block, example_tensor_map);

// the mapping from lowering graph => fallback global graph
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
Expand All @@ -240,7 +237,7 @@ GraphAndMapping ConstructFallbackGraph(
}

for (auto& seg_block : segmented_blocks) {
LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
LOG_INFO("Block segment:" << seg_block);
std::ostringstream trt_engine_id;
trt_engine_id << reinterpret_cast<const int*>(&seg_block);

Expand All @@ -254,12 +251,12 @@ GraphAndMapping ConstructFallbackGraph(
inputs.push_back(in);
}
// update the input ranges for each segments
convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

// TODO mapping Inputs Ivalue to flatten one here
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_cfg.engine_settings.device;
auto device_spec = convert_info.engine_settings.device;
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

Expand All @@ -272,8 +269,8 @@ GraphAndMapping ConstructFallbackGraph(
// convert the 2 blocks in prim::if and get the converted graph with mappings
std::vector<GraphAndMapping> graph_and_mappings;
for (auto cur_block : if_node->blocks()) {
graph_and_mappings.push_back(
ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
graph_and_mappings.push_back(ConstructFallbackGraph_(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not wildly pressing but is there a way to do all the partitioning beforehand then go through and compile specific blocks? Having them mixed is not as easy to debug

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bowang007 thoughts here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps what we do is recursively partition, then recursively compile the final graph. Not sure if graph stitching can handle this right now.

new_mod, cur_block, partitioning_ctx, convert_info, static_params, example_tensor_map));
}
AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);

Expand Down Expand Up @@ -303,13 +300,32 @@ GraphAndMapping ConstructFallbackGraph(
return {new_g, old_to_new_g};
}

// Public entry point for fallback-graph construction. Sets up the
// partitioning context and example inputs, then delegates the actual
// recursive partition-and-convert work to ConstructFallbackGraph_.
GraphAndMapping ConstructFallbackGraph(
    torch::jit::script::Module& new_mod,
    torch::jit::Block* block,
    CompileSpec cfg,
    ir::StaticParams static_params,
    ir::CollectionTypeMap first_use_types) {
  auto conversion_settings = cfg.convert_info;
  auto partitioning_settings = cfg.partitioning_info;

  // All partitioning state for this compilation is centralized in one context object.
  partitioning::PartitioningCtx ctx(block, partitioning_settings);

  // Synthesize random example inputs so shapes can be propagated through the segments.
  auto example_inputs =
      partitioning::generateRandomInputs(partitioning_settings.collection_input_spec_map, first_use_types);

  return ConstructFallbackGraph_(new_mod, block, &ctx, conversion_settings, static_params, example_inputs);
}

void MapInputsAndDetermineDTypes(
CompileSpec& cfg,
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
ir::CollectionTypeMap& first_use_type_map) {
cfg.convert_info.collection_input_spec_map =
std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
cfg.partitioning_info.collection_input_spec_map =
ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);

auto collection_inputs = ir::get_collection_inputs(g, static_params);
LOG_DEBUG(
Expand Down Expand Up @@ -339,7 +355,7 @@ void MapInputsAndDetermineDTypes(
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec[i].dtype = nvinfer1::DataType::kFLOAT;
} else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
} else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
if (!est_type_opt[i]) {
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
std::stringstream ss;
Expand Down Expand Up @@ -424,21 +440,17 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto outputIsCollection = conversion::OutputIsCollection(g->block());
if (cfg.partition_info.enabled &&
if (cfg.partitioning_info.enabled &&
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
LOG_INFO("Skipping partitioning since model is fully supported");
}

if (cfg.partition_info.enabled &&
if (cfg.partitioning_info.enabled &&
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
outputIsCollection)) {
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
auto collection_input_ivalues_map =
partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
auto graph_and_mapping = ConstructFallbackGraph(
new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), cfg, static_params, first_use_types);
new_g = graph_and_mapping.first;
// renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
for (size_t i = 0; i < new_g->inputs().size(); ++i) {
Expand Down
2 changes: 1 addition & 1 deletion core/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct CompileSpec {
ir::GraphInputs graph_inputs;
conversion::ConversionInfo convert_info;
lowering::LowerInfo lower_info;
partitioning::PartitionInfo partition_info;
partitioning::PartitioningInfo partitioning_info;
};

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
Expand Down
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
passes::MarkNodesForFallback(g, true);
}
passes::UnpackHardSwish(g);
passes::UnpackHardSigmoid(g);
passes::EliminateExceptionOrPassPattern(g);
passes::ReduceToOperation(g);
passes::ReduceGelu(g);
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cc_library(
"silu_to_sigmoid_multiplication.cpp",
"unpack_addmm.cpp",
"unpack_batch_norm.cpp",
"unpack_hardsigmoid.cpp",
"unpack_hardswish.cpp",
"unpack_log_softmax.cpp",
"unpack_std.cpp",
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ target_sources(${lib_name}
"${CMAKE_CURRENT_SOURCE_DIR}/silu_to_sigmoid_multiplication.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_addmm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_batch_norm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
Expand Down
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);

} // namespace passes
} // namespace lowering
Expand Down
43 changes: 43 additions & 0 deletions core/lowering/passes/unpack_hardsigmoid.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

// Lowers aten::hardsigmoid / aten::hardsigmoid_ into primitive ops
// (div, add, clamp) so downstream converters do not need a dedicated
// hardsigmoid converter: hardsigmoid(x) == clamp(x / 6 + 0.5, 0, 1).
void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph) {
  // Out-of-place form.
  std::string out_of_place_pattern = R"IR(
    graph(%input):
      %result = aten::hardsigmoid(%input)
      return (%result))IR";

  // In-place form; rewritten to the same non-mutating subgraph.
  std::string in_place_pattern = R"IR(
    graph(%input):
      %result = aten::hardsigmoid_(%input)
      return (%result))IR";

  // Unpacked equivalent: clamp(x / 6 + 1/2, 0, 1).
  std::string unpacked_pattern = R"IR(
    graph(%x.1):
      %22 : float = prim::Constant[value=0.5]()
      %3 : int = prim::Constant[value=6]()
      %5 : int = prim::Constant[value=1]()
      %10 : int = prim::Constant[value=0]()
      %4 : Tensor = aten::div(%x.1, %3)
      %9 : Tensor = aten::add(%4, %22, %5)
      %21 : Tensor = aten::clamp(%9, %10, %5)
      return (%21))IR";

  torch::jit::SubgraphRewriter unpack_rewriter;
  unpack_rewriter.RegisterRewritePattern(out_of_place_pattern, unpacked_pattern);
  unpack_rewriter.RegisterRewritePattern(in_place_pattern, unpacked_pattern);
  unpack_rewriter.runOnGraph(graph);

  LOG_GRAPH("Post unpack hardsigmoid: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
11 changes: 3 additions & 8 deletions core/partitioning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@ config_setting(
cc_library(
name = "partitioning",
srcs = [
"PartitionInfo.cpp",
"SegmentedBlock.cpp",
"partitioning.cpp",
"shape_analysis.cpp",
],
hdrs = [
"PartitionInfo.h",
"SegmentedBlock.h",
"partitioning.h",
"shape_analysis.h",
],
deps = [
"//core/util:prelude",
"//core/ir",
"//core/conversion",
"//core/lowering",
"//core/partitioning/partitioningctx",
"//core/partitioning/partitioninginfo",
"//core/partitioning/segmentedblock",
] + select({
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
Expand All @@ -39,10 +37,7 @@ cc_library(
pkg_tar(
name = "include",
srcs = [
"PartitionInfo.h",
"SegmentedBlock.h",
"partitioning.h",
"shape_analysis.h",
],
package_dir = "core/partitioning/",
)
38 changes: 22 additions & 16 deletions core/partitioning/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,33 +1,39 @@
set(lib_name "core_partitioning")
add_library(${lib_name} OBJECT)

target_sources(${lib_name}
PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shape_analysis.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/partitioning.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/PartitionInfo.cpp"
$<TARGET_OBJECTS:core_conversion>
PUBLIC $<TARGET_OBJECTS:core_ir>
$<TARGET_OBJECTS:core_util>
set(CXX_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/partitioning.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/shape_analysis.cpp"
)

set(HEADER_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.h"
"${CMAKE_CURRENT_SOURCE_DIR}/shape_analysis.h"
"${CMAKE_CURRENT_SOURCE_DIR}/PartitionInfo.h"
"${CMAKE_CURRENT_SOURCE_DIR}/partitioning.h"
)

target_include_directories(${lib_name} PUBLIC "$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}>")
target_sources(${lib_name}
PRIVATE
${CXX_SRCS}
PUBLIC
$<TARGET_OBJECTS:core_conversion>
$<TARGET_OBJECTS:core_ir>
$<TARGET_OBJECTS:core_util>
)

target_link_libraries(${lib_name}
PUBLIC
torch
TensorRT::nvinfer
torch
core_ir
core_util
PRIVATE
core_conversion
)

# Install headers
install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning/")
target_include_directories(${lib_name}
PUBLIC "$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}>"
)

add_subdirectory(partitioningctx)
add_subdirectory(partitioninginfo)
add_subdirectory(segmentedblock)

install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning")
Loading