diff --git a/core/compiler.cpp b/core/compiler.cpp index bf128b714a..118ca7aa1c 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -11,7 +11,6 @@ #include "torch/csrc/jit/frontend/function_schema_parser.h" #include "torch/csrc/jit/ir/ir.h" -#include "torch/csrc/jit/ir/ir_views.h" #include "torch/csrc/jit/passes/graph_fuser.h" #include "torch/csrc/jit/passes/loop_unrolling.h" #include "torch/csrc/jit/passes/lower_graph.h" @@ -128,179 +127,54 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri return conversion::VerifyConverterSupportForBlock(g->block()); } -void AddSegmentedBlockToGraph( - std::shared_ptr& g, - partitioning::SegmentedBlock& seg, - std::unordered_map& old_to_new_g) { - // old_to_new_g contains: original global graph value => new global graph value, - // mini_to_new_g: mini graph value -> new graph value - std::unordered_map mini_to_new_g; - size_t input_idx = 0; - if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) { - if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) { - auto self = g->insertInput(0, "self_1"); - self->setType(seg.inputs()[0]->type()); - } - mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0]; - } - - for (auto& raw_input : seg.raw_inputs()) { - if (old_to_new_g.count(raw_input)) { - mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input]; - } - } - - for (const auto n : seg.nodes()) { - util::cloneNode(n, g, mini_to_new_g); - } - - // original graph value => new global graph value - for (size_t i = 0; i < seg.raw_outputs().size(); ++i) { - old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]]; - } - size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0; - for (size_t i = 0; i < seg.raw_inputs().size(); ++i) { - if (!old_to_new_g.count(seg.raw_inputs()[i])) { - old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]]; - } - } - - return; -} - -typedef std::pair, std::unordered_map> - GraphAndMapping; - -void AddIfBlockToGraph( - std::shared_ptr& new_g, - torch::jit::Node* if_node, - const std::vector& graph_and_mappings, - std::unordered_map& old_to_new_g) { - torch::jit::IfView if_view(if_node); - - // create a new if node in new_g and add corresponding inputs - auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0)); - new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g)); - - // iterate over all blocks and add them to new created prim::If - for (auto graph_and_mapping : graph_and_mappings) { - auto new_if_block = new_if->addBlock(); - auto cur_block_graph = graph_and_mapping.first; - auto cur_block_mapping = graph_and_mapping.second; - std::unordered_map block_graph_to_new_g; - for (auto& i : cur_block_mapping) { - // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then - // it's mini graph's input - if (old_to_new_g.count(i.first)) { - block_graph_to_new_g[i.second] = old_to_new_g[i.first]; - } - } - - auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); }; - new_if_block->cloneFrom(cur_block_graph->block(), env); - if (cur_block_graph->inputs().size() && - cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) { - if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) { - auto self = new_g->insertInput(0, "self_1"); - self->setType(cur_block_graph->inputs()[0]->type()); - 
} - block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0]; - } - for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) { - new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]); - new_if_block->eraseInput(i); - } - } - for (auto ov : if_view.outputs()) { - auto no = new_if->addOutput(); - old_to_new_g[ov] = no; - no->copyMetadata(ov); - } - return; -} - -GraphAndMapping ConstructFallbackGraph( +partitioning::GraphAndMapping BuildHybridGraph( torch::jit::script::Module& new_mod, torch::jit::Block* block, - std::unordered_map example_tensor_map, CompileSpec cfg, ir::StaticParams static_params, - std::unordered_map& fallback_nodes) { - auto convert_cfg = cfg.convert_info; - auto partition_info = cfg.partition_info; - - auto new_g = std::make_shared(); - - auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes); - - // the mapping from lowering graph => fallback global graph - std::unordered_map old_to_new_g; - for (auto input : block->inputs()) { - util::getOrAddInputForValue(input, new_g, old_to_new_g); - } - - for (auto& seg_block : segmented_blocks) { - LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n"); - std::ostringstream trt_engine_id; - trt_engine_id << reinterpret_cast(&seg_block); - - if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) { - auto shapes = seg_block.in_shapes(); - auto types = seg_block.in_types(); - std::vector inputs; - for (size_t i = 0; i < shapes.size(); i++) { - auto in = ir::Input(shapes[i]); - in.dtype = util::ScalarTypeToTRTDataType(types[i]); - inputs.push_back(in); - } - // update the input ranges for each segments - convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params); - - // TODO mapping Inputs Ivalue to flatten one here - auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params); - auto temp_g = std::make_shared(); - auto device_spec = convert_cfg.engine_settings.device; - auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type); - AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true); - - seg_block.update_graph(temp_g); - AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g); - } else { - if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) { - auto if_node = seg_block.raw_nodes()[0]; - - // convert the 2 blocks in prim::if and get the converted graph with mappings - std::vector graph_and_mappings; - for (auto cur_block : if_node->blocks()) { - graph_and_mappings.push_back( - ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes)); + ir::CollectionTypeMap first_use_types) { + auto convert_info = cfg.convert_info; + auto partitioning_info = cfg.partitioning_info; + + auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info); + auto collection_input_ivalues_map = + partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types); + + partitioning::partition(&partitioning_ctx, collection_input_ivalues_map); + + for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) { + partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second; + + for (auto& seg_block : segmented_blocks) { + LOG_INFO("Block segment:" << seg_block); + std::ostringstream trt_engine_id; + trt_engine_id << reinterpret_cast(&seg_block); + + if (seg_block.target() == 
partitioning::SegmentedBlock::kTensorRT) { + auto shapes = seg_block.in_shapes(); + auto types = seg_block.in_types(); + std::vector inputs; + for (size_t i = 0; i < shapes.size(); i++) { + auto in = ir::Input(shapes[i]); + in.dtype = util::ScalarTypeToTRTDataType(types[i]); + inputs.push_back(in); } - AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g); + // update the input ranges for each segments + convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params); - } else { - AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g); - } - } - } + // TODO mapping Inputs Ivalue to flatten one here + auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params); + auto temp_g = std::make_shared(); + auto device_spec = convert_info.engine_settings.device; + auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type); + AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true); - if (block->outputs().size() > 1) { - std::vector fallback_graph_vector; - for (auto& output : block->outputs()) { - if (old_to_new_g.count(output)) { - fallback_graph_vector.push_back(old_to_new_g[output]); + seg_block.update_graph(temp_g); } } - torch::jit::ArrayRef fallback_graph_outputs(fallback_graph_vector); - auto return_tuple_node = new_g->createTuple(fallback_graph_outputs); - new_g->block()->appendNode(return_tuple_node); - // Set the output as the produced tuple - new_g->registerOutput(return_tuple_node->outputs()[0]); - } else { - if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) { - new_g->registerOutput(old_to_new_g[block->outputs()[0]]); - } } - return {new_g, old_to_new_g}; + + return partitioning::stitch(&partitioning_ctx, block); } void MapInputsAndDetermineDTypes( @@ -310,6 +184,8 @@ void MapInputsAndDetermineDTypes( ir::CollectionTypeMap& first_use_type_map) { cfg.convert_info.collection_input_spec_map = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params)); + cfg.partitioning_info.collection_input_spec_map = + ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map); auto collection_inputs = ir::get_collection_inputs(g, static_params); LOG_DEBUG( @@ -339,7 +215,7 @@ void MapInputsAndDetermineDTypes( "Cannot infer input type from calcuations in graph for input " << in->debugName() << ". Assuming it is Float32. 
If not, specify input type explicity"); spec[i].dtype = nvinfer1::DataType::kFLOAT; - } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) { + } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) { if (!est_type_opt[i]) { LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting"); std::stringstream ss; @@ -424,22 +300,18 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types); auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true); auto outputIsCollection = conversion::OutputIsCollection(g->block()); - if (cfg.partition_info.enabled && + if (cfg.partitioning_info.enabled && (cfg.lower_info.forced_fallback_modules.size() == 0 && - cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) && + cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) && !outputIsCollection) { LOG_INFO("Skipping partitioning since model is fully supported"); } - if (cfg.partition_info.enabled && + if (cfg.partitioning_info.enabled && (!(cfg.lower_info.forced_fallback_modules.size() == 0 && - cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) || + cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) || outputIsCollection)) { - std::unordered_map fallback_nodes; - auto collection_input_ivalues_map = - partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types); - auto graph_and_mapping = ConstructFallbackGraph( - new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes); + auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types); new_g = graph_and_mapping.first; // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly for (size_t i = 0; i < new_g->inputs().size(); ++i) { diff --git a/core/compiler.h b/core/compiler.h index c8dc85020b..1b7b3defe8 100644 --- a/core/compiler.h +++ b/core/compiler.h @@ -19,7 +19,7 @@ struct CompileSpec { ir::GraphInputs graph_inputs; conversion::ConversionInfo convert_info; lowering::LowerInfo lower_info; - partitioning::PartitionInfo partition_info; + partitioning::PartitioningInfo partitioning_info; }; bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name); diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 8bbae296c3..5442440422 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -41,6 +41,7 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { passes::MarkNodesForFallback(g, true); } passes::UnpackHardSwish(g); + passes::UnpackHardSigmoid(g); passes::EliminateExceptionOrPassPattern(g); passes::ReduceToOperation(g); passes::ReduceGelu(g); diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD index 1f6a0cde8f..d5f3616f8d 100644 --- a/core/lowering/passes/BUILD +++ b/core/lowering/passes/BUILD @@ -30,6 +30,7 @@ cc_library( "silu_to_sigmoid_multiplication.cpp", "unpack_addmm.cpp", "unpack_batch_norm.cpp", + "unpack_hardsigmoid.cpp", "unpack_hardswish.cpp", "unpack_log_softmax.cpp", "unpack_std.cpp", diff --git a/core/lowering/passes/CMakeLists.txt b/core/lowering/passes/CMakeLists.txt index a8cda65e71..48e644a70d 100644 --- a/core/lowering/passes/CMakeLists.txt +++ b/core/lowering/passes/CMakeLists.txt @@ -17,6 +17,7 @@ 
target_sources(${lib_name} "${CMAKE_CURRENT_SOURCE_DIR}/silu_to_sigmoid_multiplication.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_addmm.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_batch_norm.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardsigmoid.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_hardswish.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_log_softmax.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp" diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index 73bd9f61d7..3b946593e2 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -38,6 +38,7 @@ void UnpackVar(std::shared_ptr& graph); void AliasOperators(std::shared_ptr& graph); void SiluToSigmoidMultipication(std::shared_ptr& graph); void UnpackHardSwish(std::shared_ptr& graph); +void UnpackHardSigmoid(std::shared_ptr& graph); } // namespace passes } // namespace lowering diff --git a/core/lowering/passes/unpack_hardsigmoid.cpp b/core/lowering/passes/unpack_hardsigmoid.cpp new file mode 100644 index 0000000000..876196215a --- /dev/null +++ b/core/lowering/passes/unpack_hardsigmoid.cpp @@ -0,0 +1,43 @@ +#include "torch/csrc/jit/passes/subgraph_rewrite.h" + +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace lowering { +namespace passes { + +void UnpackHardSigmoid(std::shared_ptr& graph) { + std::string hardsigmoid_pattern = R"IR( + graph(%input): + %result = aten::hardsigmoid(%input) + return (%result))IR"; + + std::string hardsigmoid_pattern_inplace = R"IR( + graph(%input): + %result = aten::hardsigmoid_(%input) + return (%result))IR"; + + std::string new_pattern = R"IR( + graph(%x.1): + %22 : float = prim::Constant[value=0.5]() + %3 : int = prim::Constant[value=6]() + %5 : int = prim::Constant[value=1]() + %10 : int = prim::Constant[value=0]() + %4 : Tensor = aten::div(%x.1, %3) + %9 : Tensor = aten::add(%4, %22, %5) + %21 : Tensor = aten::clamp(%9, %10, %5) + return (%21))IR"; + + torch::jit::SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern(hardsigmoid_pattern, new_pattern); + rewriter.RegisterRewritePattern(hardsigmoid_pattern_inplace, new_pattern); + rewriter.runOnGraph(graph); + + LOG_GRAPH("Post unpack hardsigmoid: " << *graph); +} + +} // namespace passes +} // namespace lowering +} // namespace core +} // namespace torch_tensorrt diff --git a/core/partitioning/BUILD b/core/partitioning/BUILD index fbc9eeac7a..4204939684 100644 --- a/core/partitioning/BUILD +++ b/core/partitioning/BUILD @@ -13,22 +13,21 @@ config_setting( cc_library( name = "partitioning", srcs = [ - "PartitionInfo.cpp", - "SegmentedBlock.cpp", "partitioning.cpp", "shape_analysis.cpp", + "stitching.cpp", ], hdrs = [ - "PartitionInfo.h", - "SegmentedBlock.h", "partitioning.h", - "shape_analysis.h", ], deps = [ "//core/util:prelude", "//core/ir", "//core/conversion", "//core/lowering", + "//core/partitioning/partitioningctx", + "//core/partitioning/partitioninginfo", + "//core/partitioning/segmentedblock", ] + select({ ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], "//conditions:default": ["@libtorch//:libtorch"], @@ -39,10 +38,7 @@ cc_library( pkg_tar( name = "include", srcs = [ - "PartitionInfo.h", - "SegmentedBlock.h", "partitioning.h", - "shape_analysis.h", ], package_dir = "core/partitioning/", ) diff --git a/core/partitioning/CMakeLists.txt b/core/partitioning/CMakeLists.txt index 15784f638e..7f83b3d891 100644 --- a/core/partitioning/CMakeLists.txt +++ b/core/partitioning/CMakeLists.txt @@ -1,33 +1,39 @@ set(lib_name "core_partitioning") 
add_library(${lib_name} OBJECT) -target_sources(${lib_name} - PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/shape_analysis.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/partitioning.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/PartitionInfo.cpp" - $ - PUBLIC $ - $ +set(CXX_SRCS + "${CMAKE_CURRENT_SOURCE_DIR}/partitioning.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/shape_analysis.cpp" ) set(HEADER_FILES - "${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.h" - "${CMAKE_CURRENT_SOURCE_DIR}/shape_analysis.h" - "${CMAKE_CURRENT_SOURCE_DIR}/PartitionInfo.h" "${CMAKE_CURRENT_SOURCE_DIR}/partitioning.h" ) -target_include_directories(${lib_name} PUBLIC "$") +target_sources(${lib_name} + PRIVATE + ${CXX_SRCS} + PUBLIC + $ + $ + $ +) + target_link_libraries(${lib_name} PUBLIC - torch TensorRT::nvinfer + torch core_ir core_util - PRIVATE core_conversion ) -# Install headers -install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning/") +target_include_directories(${lib_name} + PUBLIC "$" +) + +add_subdirectory(partitioningctx) +add_subdirectory(partitioninginfo) +add_subdirectory(segmentedblock) + +install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning") \ No newline at end of file diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index c329b33ef6..eb8c86de50 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -1,9 +1,7 @@ -#include "partitioning.h" - +#include "core/partitioning/partitioning.h" #include #include "core/conversion/conversion.h" #include "core/conversion/evaluators/evaluators.h" -#include "core/partitioning/shape_analysis.h" #include "torch/csrc/jit/passes/constant_pooling.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" @@ -30,6 +28,132 @@ bool containNonTensorOutputs(torch::jit::Node* n) { return false; } +// Check if the inputs and outputs of the graph are Tensor. 
If not, then fall back the connected nodes +void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* block) { + // fallback nodes that produce entire graph's nonTensor output + for (auto i : block->outputs()) { + if (!isTensor(i)) { + ctx->setNodeExecutorDecision(i->node(), NodeExecutorDecision::kNON_TENSOR); + } + } + + // fallback nodes that consume entire graph's nonTensor input + for (auto i : block->inputs()) { + if (!isTensor(i)) { + for (auto use : i->uses()) { + ctx->setNodeExecutorDecision(use.user, NodeExecutorDecision::kNON_TENSOR); + } + } + } +} + +// Find and set all explicit fallback nodes (nodes that are unsupported or forced to fall back) +// we use a map to record the reason why each node falls back to Torch +// For any node that's not explicitly marked as fallback, we set it to run in TensorRT for now +void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) { + auto nodes = block->nodes(); + const auto to_compile_sym = c10::Symbol::attr("to_compile"); + + for (const auto n : nodes) { + if (n->kind() == torch::jit::prim::Constant) { + continue; + } + + if (!conversion::OpSupported(n)) { + // If the op is not supported by the conversion phase it should run in PyTorch + ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED); + } else if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) { + // If the user specifies the op to run in Torch it should run in PyTorch + ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kOPERATOR_FALLBACK); + } else if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) { + // If the user specifies the module containing this op to run in torch it should run in PyTorch + ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kMODULE_FALLBACK); + } else { + // Set the remaining nodes to run in TensorRT + ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT); + } + } + return; +} + +// For a given set of fallback nodes, check their inputs/outputs; if any of those inputs/outputs are NonTensor, +// then the nodes that produce/consume those values should also fall back +void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::Node*>& initial_fallback_nodes) { + // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function + std::queue<torch::jit::Node*> q; + for (auto& node : initial_fallback_nodes) { + q.push(node); + } + + while (!q.empty()) { + auto cur_node = q.front(); + q.pop(); + // every node that produces a NonTensor input of this fallback node should fall back too + for (auto input : cur_node->inputs()) { + if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant && + ctx->shouldNodeRunInTensorRT(input->node())) { + ctx->setNodeExecutorDecision(input->node(), NodeExecutorDecision::kNON_TENSOR); + q.push(input->node()); + } + } + // every node that consumes a NonTensor output of this fallback node should fall back too + for (auto output : cur_node->outputs()) { + if (!isTensor(output)) { + for (auto use : output->uses()) { + auto node = use.user; + if (node->kind() != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT(node)) { + ctx->setNodeExecutorDecision(node, NodeExecutorDecision::kNON_TENSOR); + q.push(node); + } + } + } + } + } +} + +// Sub-function that traverses the entire block and checks whether each TensorRT node sequence satisfies min_block_size +std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx, torch::jit::Block* block) { + auto nodes = block->nodes(); + std::vector<torch::jit::Node*> 
cur_trt_nodes; + std::vector min_block_fallback_nodes; + for (const auto n : nodes) { + if (n->kind() == torch::jit::prim::Constant) { + continue; + } + + // check if current node fallback or not + if (!ctx->shouldNodeRunInTorch(n)) { + cur_trt_nodes.push_back(n); + } else { + if (cur_trt_nodes.size() < ctx->settings.min_block_size) { + min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); + } + cur_trt_nodes.clear(); + } + } + if (cur_trt_nodes.size() < ctx->settings.min_block_size) { + min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); + } + return min_block_fallback_nodes; +} + +// Set the nodes that fallback because of min_block_size +void setMinBlockFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) { + // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement + auto min_block_fallback_nodes = traverseNodesForMinBlockSize(ctx, block); + + // keep fallback until all segments meet the min_block_size requirement + while (!min_block_fallback_nodes.empty()) { + for (const auto i : min_block_fallback_nodes) { + ctx->setNodeExecutorDecision(i, NodeExecutorDecision::kMIN_BLOCK_FALLBACK); + } + // find the fallback nodes because of dependency with min_block_size caused fallback nodes + setNonTensorConnectedNodes(ctx, min_block_fallback_nodes); + // keep traverse the graph until there is no node fallback because of min_block_size + min_block_fallback_nodes = traverseNodesForMinBlockSize(ctx, block); + } +} + bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) { const torch::jit::FunctionSchema* schema = node->maybeSchema(); if (!schema) { @@ -96,91 +220,38 @@ std::vector getDependencyNodes( return stk; } -// check if the input and output of the graph is Tensor after collection is enabled. 
If it is, then fallback related -// nodes -void fallback_graph_nontensor_in_out( - torch::jit::Block* block, - std::unordered_map& global_fallback_nodes) { - // fallback nodes that produce entire graph's nonTensor output - for (auto i : block->outputs()) { - if (!isTensor(i)) { - global_fallback_nodes.insert({i->node(), FallbackNodeType::kNON_TENSOR}); - } - } - - // fallback nodes that consume entire graph's nonTensor input - for (auto i : block->inputs()) { - if (!isTensor(i)) { - for (auto use : i->uses()) { - global_fallback_nodes.insert({use.user, FallbackNodeType::kNON_TENSOR}); - } - } - } -} - -void find_all_fallback_nodes( - std::unordered_map& initial_fallback_nodes, - std::unordered_map& global_fallback_nodes) { - // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function - // global_fallback_nodes are the fallback nodes that we maintain globally - std::queue q; - for (auto& node : initial_fallback_nodes) { - q.push(node.first); - } - - std::unordered_set visited_nodes; - while (!q.empty()) { - auto cur_node = q.front(); - q.pop(); - // for every node that produces this fallback node's NonTensor input, they should fallback too - for (auto input : cur_node->inputs()) { - if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant && - global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) { - q.push(input->node()); - } - } - // for every node that consumes this fallback node's NonTensor output, they should fallback too - for (auto output : cur_node->outputs()) { - if (!isTensor(output)) { - for (auto use : output->uses()) { - auto node = use.user; - if (node->kind() != torch::jit::prim::Constant && - global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) { - q.push(node); - } - } - } - } - } -} - -void resolveTRTNonTensorInputs(PartitionedGraph& segmented_blocks) { +void resolveTRTNonTensorInputs(PartitioningCtx* ctx, torch::jit::Block* block) { // if a TRT segment has nonTensor Inputs, the nodes that produce this nonTensor Inputs must in another TensorRT engine // because we have already found the interface between Torch and TRT in segmentation phase // what we do here is just find the dependency nodes of the TRT segments that have nonTensor inputs - for (size_t i = 0; i < segmented_blocks.size(); ++i) { - if (segmented_blocks[i].target() == SegmentedBlock::kTensorRT) { + PartitionedGraph& cur_partitioned_block = ctx->partitioned_blocks[block]; + for (size_t i = 0; i < cur_partitioned_block.size(); ++i) { + if (cur_partitioned_block[i].target() == SegmentedBlock::kTensorRT) { std::vector inputs_to_resolve; - for (auto input : segmented_blocks[i].raw_inputs()) { + for (auto input : cur_partitioned_block[i].raw_inputs()) { if (!isTensor(input)) { inputs_to_resolve.push_back(input); } } if (!inputs_to_resolve.empty()) { - std::vector dependency_nodes = getDependencyNodes(inputs_to_resolve, segmented_blocks[i]); + std::vector dependency_nodes = + getDependencyNodes(inputs_to_resolve, cur_partitioned_block[i]); dependency_nodes.insert( - dependency_nodes.end(), segmented_blocks[i].raw_nodes().begin(), segmented_blocks[i].raw_nodes().end()); - segmented_blocks[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes); + dependency_nodes.end(), + cur_partitioned_block[i].raw_nodes().begin(), + cur_partitioned_block[i].raw_nodes().end()); + cur_partitioned_block[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes); } } } } -void 
registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) { +void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) { // find the corresponding raw values in original global graph for this segmented block's inputs/outputs + PartitionedGraph& cur_partitioned_block = ctx->partitioned_blocks[block]; auto cmp = [](torch::jit::Value* a, torch::jit::Value* b) { return a->unique() < b->unique(); }; std::set<torch::jit::Value*, decltype(cmp)> input_values(cmp); - for (auto& seg_block : segmented_blocks) { + for (auto& seg_block : cur_partitioned_block) { for (auto& input : seg_block.raw_inputs()) { input_values.insert(input); } @@ -193,7 +264,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo // should be careful here because some in-place operations don't return any values, there is no output for this kind // of segment identify the output for each mini-graph by checking if any value in this graph is used later we // shouldn't register nonTensor output for TensorRT segments - for (auto& seg_block : segmented_blocks) { + for (auto& seg_block : cur_partitioned_block) { for (auto& mini_graph_input : input_values) { if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) == seg_block.raw_inputs().end() && @@ -222,20 +293,21 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo } } - std::for_each(segmented_blocks.begin(), segmented_blocks.end(), [](SegmentedBlock& seg_block) { + std::for_each(cur_partitioned_block.begin(), cur_partitioned_block.end(), [](SegmentedBlock& seg_block) { torch::jit::EliminateDeadCode(seg_block.g()); }); // erase segments which still have no output - segmented_blocks.erase( + cur_partitioned_block.erase( std::remove_if( - segmented_blocks.begin(), - segmented_blocks.end(), + cur_partitioned_block.begin(), + cur_partitioned_block.end(), [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }), - segmented_blocks.end()); + cur_partitioned_block.end()); return; } +// TODO: Need to check whether this makes sense; it might be a root cause of some over-aggressive fallback issues bool checkLoopEvaluatable(torch::jit::Node* n) { bool compile_to_trt = true; for (auto bn : n->blocks()[0]->nodes()) { @@ -250,29 +322,7 @@ bool checkLoopEvaluatable(torch::jit::Node* n) { return compile_to_trt; } -bool check_node_fallback(torch::jit::Node* n, const std::unordered_map& fallback_nodes) { - if (fallback_nodes.count(n)) { - if (fallback_nodes.at(n) == FallbackNodeType::kUNSUPPORTED) { - LOG_GRAPH("Node not supported by conversion: " << util::node_info(n)); - } else if (fallback_nodes.at(n) == FallbackNodeType::kOPERATOR_FALLBACK) { - LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n)); - } else if (fallback_nodes.at(n) == FallbackNodeType::kMODULE_FALLBACK) { - LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n)); - } else if (fallback_nodes.at(n) == FallbackNodeType::kMIN_BLOCK_FALLBACK) { - LOG_GRAPH("Node fallback to Torch because of min_block_size" << util::node_info(n)); - } else { - LOG_GRAPH( - "Node fallback to Torch because the NonTensor dependencies with other fallback nodes: " - << util::node_info(n)); - } - return false; - } - - LOG_GRAPH("Node is going to run in TensorRT: " << util::node_info(n)); - return true; -} - -void finalize_block( +void finalizeNewBlock( PartitionedGraph& g, SegmentedBlock::SegmentedBlockTarget kind, std::vector<torch::jit::Node*>& nodes) { @@ -282,110 +332,38 @@ 
LOG_DEBUG(g.back()); } -// use this function to get all initial fallback nodes (nodes that are unsupported or forced fallback) -// we use a map to indicate the reason why it's fallback to torch -void get_fallback_nodes( - torch::jit::Block* block, - const std::unordered_set& forced_fallback_ops, - std::unordered_map& fallback_nodes) { - auto nodes = block->nodes(); - for (const auto n : nodes) { - if (n->kind() == torch::jit::prim::Constant) { - continue; - } - - // If the op is not supported by the conversion phase it should run in PyTorch - if (!conversion::OpSupported(n)) { - fallback_nodes.insert({n, FallbackNodeType::kUNSUPPORTED}); - } - - // If the user specifies the op to run in Torch it should run in PyTorch - if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) { - fallback_nodes.insert({n, FallbackNodeType::kOPERATOR_FALLBACK}); - } - - // If the user specifies the module containing this op to run in torch it should run in PyTorch - const auto to_compile_sym = c10::Symbol::attr("to_compile"); - if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) { - fallback_nodes.insert({n, FallbackNodeType::kMODULE_FALLBACK}); - } - } - return; +void setNodeExecutorLUT(PartitioningCtx* ctx, torch::jit::Block* block) { + // First, find all the explicit fallback nodes that should run in Torch: + // 1. nodes that are unsupported + // 2. nodes that the user specifies to run in torch + // 3. nodes that the user specifies the module containing this op to run in torch + // At the same time, set all the rest nodes to NodeExecutorDecision::kCONVERT + setExplicitFallbackNodes(ctx, block); + + // Second, check if there is nonTensor input/output for the block, if there is, then fallback the nodes that + // consume/produce this nonTensor value + setInputsOutputsConnectedNodes(ctx, block); + + // Third, for fallback nodes, if it consumes any NonTensor inputs, then the nodes that produce this + // input should also fallback. Similarly, if it produces any NonTensor outputs, then the nodes + // that consume this output should also fallback + auto cur_fallback_nodes = ctx->getNodesRunInTorch(); + setNonTensorConnectedNodes(ctx, cur_fallback_nodes); + + // Finally, check if all current tensorrt blocks satisfy the min_block_size requirement. 
+ // We need to traverse the whole graph many times here + setMinBlockFallbackNodes(ctx, block); } -std::vector traverse_nodes_for_min_block_size( - torch::jit::Block* block, - const std::unordered_map& global_fallback_nodes, - size_t min_block_size) { - auto nodes = block->nodes(); - std::vector cur_trt_nodes; - std::vector min_block_fallback_nodes; - for (const auto n : nodes) { - if (n->kind() == torch::jit::prim::Constant) - continue; - - // check if current node fallback or not - if (!global_fallback_nodes.count(n)) { - // if this node is not in fallback nodes, then it's in trt segments - cur_trt_nodes.push_back(n); - } else { - if (cur_trt_nodes.size() < min_block_size) { - min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); - } - cur_trt_nodes.clear(); - } - } - if (cur_trt_nodes.size() < min_block_size) { - min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end()); - } - return min_block_fallback_nodes; -} - -void find_min_block_size_fallback_nodes( - torch::jit::Block* block, - std::unordered_map& global_fallback_nodes, - size_t min_block_size) { - // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement - auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size); - std::unordered_map initial_fallback_nodes; - - // keep fallback until all segments meet the min_block_size requirement - while (!min_block_fallback_nodes.empty()) { - for (const auto i : min_block_fallback_nodes) { - initial_fallback_nodes.insert({i, FallbackNodeType::kMIN_BLOCK_FALLBACK}); - } - global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end()); - // find the fallback nodes because of dependency with min_block_size caused fallback nodes - find_all_fallback_nodes(initial_fallback_nodes, global_fallback_nodes); - // keep traverse the graph until there is no node fallback because of min_block_size - min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size); - } -} - -PartitionedGraph segment_graph( - torch::jit::Block* block, - const PartitionInfo& partition_info, - std::unordered_map& global_fallback_nodes) { - auto min_block_size = partition_info.min_block_size; - std::unordered_set forced_fallback_ops( - partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end()); - - // get the initial fallback nodes (nodes that are unsupported or forced fallback) - get_fallback_nodes(block, forced_fallback_ops, global_fallback_nodes); - - // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this - // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node - // that produces this input should also fallback - // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported - find_all_fallback_nodes(global_fallback_nodes, global_fallback_nodes); - - // find all fallback nodes because of the min_block_size requirement - find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size); +void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) { + // Find all the fallback nodes and build execution decision LUT for all nodes + setNodeExecutorLUT(ctx, block); auto nodes = block->nodes(); - PartitionedGraph segmented_blocks; // segment the 
nodes + PartitionedGraph segmented_blocks; + std::vector in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes; for (const auto n : nodes) { // Skip constant nodes as they are resources for both kinds of modules @@ -393,22 +371,24 @@ PartitionedGraph segment_graph( continue; } // the outputs of trt subgraph shouldn't be collections - if (check_node_fallback(n, global_fallback_nodes)) { + if (ctx->shouldNodeRunInTensorRT(n)) { in_prog_trt_blk_nodes.push_back(n); // If there is an active PyTorch block and we have passed the threshold for a valid TRT // block then segment and reset the active PyTorch block - if (in_prog_trt_blk_nodes.size() >= min_block_size && !in_prog_pyt_blk_nodes.empty()) { - finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); + if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size && !in_prog_pyt_blk_nodes.empty()) { + finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); } } else { // If there is an active TRT block that is valid segment and reset the active TRT block // otherwise add it to the active PyTorch block and reset - if (in_prog_trt_blk_nodes.size() >= min_block_size) { - finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes); + if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) { + finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes); } else { LOG_DEBUG( - "In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block"); + "In progress TRT block does not meet minimum block size requirements (" + << in_prog_trt_blk_nodes.size() << ", expected at least " << ctx->settings.min_block_size + << "), therefore folding into in progress PyTorch block"); in_prog_pyt_blk_nodes.insert( in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end()); } @@ -419,20 +399,20 @@ PartitionedGraph segment_graph( LOG_DEBUG( "Hit a conditional statement, finializing in progress PYT block and creating a new one for the conditional"); if (!in_prog_pyt_blk_nodes.empty()) { - finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); + finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); } auto cond_node = std::vector{n}; - finalize_block(segmented_blocks, SegmentedBlock::kTorch, cond_node); + finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node); continue; } else if (n->kind() == torch::jit::prim::Loop) { if (!in_prog_pyt_blk_nodes.empty()) { - finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); + finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); } if (checkLoopEvaluatable(n)) { in_prog_trt_blk_nodes.push_back(n); } else { auto loop_node = std::vector{n}; - finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node); + finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node); } continue; } @@ -442,60 +422,39 @@ PartitionedGraph segment_graph( // if there is any kTorch nodes left, then either the last nodes are kTorch or last nodes are kTensorRT but num < // min_block_size - if (in_prog_trt_blk_nodes.size() >= min_block_size) { - finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes); + if (in_prog_trt_blk_nodes.size() >= ctx->settings.min_block_size) { + finalizeNewBlock(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes); } if (!in_prog_pyt_blk_nodes.empty() || 
!in_prog_trt_blk_nodes.empty()) { in_prog_pyt_blk_nodes.insert( in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end()); - finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); + finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); } - return segmented_blocks; -} - -PartitionedGraph Partition( - torch::jit::Block* block, - std::unordered_map& example_tensor_map, - const PartitionInfo& partition_info, - std::unordered_map& global_fallback_nodes) { - LOG_DEBUG(partition_info); - // if there is nonTensor input/output for the entire graph, fallback the node that consumes/produces this nonTensor - // output - fallback_graph_nontensor_in_out(block, global_fallback_nodes); - - // segment lowering global graph into blocks - LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks"); - PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes); - // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks - - // resolve nonTensor inputs/outputs - resolveTRTNonTensorInputs(segmented_blocks); - - // register input/output torch::jit::Value for segmented graphs - LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs"); - registerSegmentsOutputs(segmented_blocks, block); + ctx->partitioned_blocks.insert({block, segmented_blocks}); + return; +} - // run shape analysis on each segmented block - runShapeAnalysis(segmented_blocks, example_tensor_map, partition_info); +void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) { + LOG_DEBUG(ctx->settings); - for (uint64_t i = 0; i < segmented_blocks.size(); i++) { - segmented_blocks[i].update_id(i); - } + // Go through all the blocks to do the partitioning + for (torch::jit::Block* block : ctx->original_blocks) { + // segment lowering global graph into blocks + segmentGraph(ctx, block); - LOG_INFO(segmented_blocks); + // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks + // resolve nonTensor inputs/outputs + resolveTRTNonTensorInputs(ctx, block); - return segmented_blocks; -} + // register input/output torch::jit::Value for segmented graphs + LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs"); + registerSegmentsOutputs(ctx, block); -std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g) { - os << "Partitioned Graph: ["; - for (auto b : g) { - os << b; + // run shape analysis on each segmented block + runShapeAnalysis(ctx, block, example_tensor_map); } - os << "]"; - return os; } } // namespace partitioning diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h index f1eb38df8a..3038f6c52f 100644 --- a/core/partitioning/partitioning.h +++ b/core/partitioning/partitioning.h @@ -3,45 +3,30 @@ #include #include +#include "torch/csrc/jit/ir/ir.h" + #include "core/ir/ir.h" -#include "core/partitioning/PartitionInfo.h" -#include "core/partitioning/SegmentedBlock.h" -#include "core/partitioning/shape_analysis.h" +#include "core/partitioning/partitioningctx/PartitioningCtx.h" #include "core/util/prelude.h" -#include "torch/csrc/jit/ir/ir.h" namespace torch_tensorrt { namespace core { namespace partitioning { -typedef std::vector PartitionedGraph; - -enum FallbackNodeType { - /// Node is not supported by TensorRT - kUNSUPPORTED, - /// Node is explicitly forced to fallback to Pytorch due to operator 
fallback - kOPERATOR_FALLBACK, - /// Node is explicitly forced to fallback to Pytorch due to module fallback - kMODULE_FALLBACK, - /// This node is in a TRT segment which does not satisfy min_block_size - /// and hence is forced to fallback. - kMIN_BLOCK_FALLBACK, - /// This node produces/consumes non-tensor inputs - kNON_TENSOR, -}; - -PartitionedGraph segment_graph( - torch::jit::Block* block, - const PartitionInfo& partition_info, - std::unordered_map& fallback_nodes); - -PartitionedGraph Partition( - torch::jit::Block* block, - std::unordered_map& example_tensor_map, - const PartitionInfo& partition_info, - std::unordered_map& fallback_nodes); - -std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g); +typedef std::unordered_map ExampleIValues; + +typedef std::pair, std::unordered_map> + GraphAndMapping; + +ExampleIValues generateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types); + +void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps); + +void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block); + +GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block); + +void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map); } // namespace partitioning } // namespace core diff --git a/core/partitioning/partitioningctx/BUILD b/core/partitioning/partitioningctx/BUILD new file mode 100644 index 0000000000..6895f8d451 --- /dev/null +++ b/core/partitioning/partitioningctx/BUILD @@ -0,0 +1,40 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") +load("@rules_pkg//:pkg.bzl", "pkg_tar") + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "use_pre_cxx11_abi", + values = { + "define": "abi=pre_cxx11_abi", + }, +) + +cc_library( + name = "partitioningctx", + srcs = [ + "PartitioningCtx.cpp", + ], + hdrs = [ + "PartitioningCtx.h", + ], + deps = [ + "//core/util:prelude", + "//core/ir", + "//core/conversion", + "//core/partitioning/segmentedblock", + "//core/partitioning/partitioninginfo", + ] + select({ + ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], + "//conditions:default": ["@libtorch//:libtorch"], + }), + alwayslink = True, +) + +pkg_tar( + name = "include", + srcs = [ + "PartitioningCtx.h", + ], + package_dir = "core/partitioning/partitioningctx", +) diff --git a/core/partitioning/partitioningctx/CMakeLists.txt b/core/partitioning/partitioningctx/CMakeLists.txt new file mode 100644 index 0000000000..090167f829 --- /dev/null +++ b/core/partitioning/partitioningctx/CMakeLists.txt @@ -0,0 +1,12 @@ +set(sub_lib_name "partitioningctx") + +target_sources(${lib_name} + PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/PartitioningCtx.cpp" +) + +set(HEADER_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/PartitioningCtx.h" +) + +# Install headers +install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning/${sub_lib_name}") diff --git a/core/partitioning/partitioningctx/PartitioningCtx.cpp b/core/partitioning/partitioningctx/PartitioningCtx.cpp new file mode 100644 index 0000000000..7bcaaea120 --- /dev/null +++ b/core/partitioning/partitioningctx/PartitioningCtx.cpp @@ -0,0 +1,113 @@ +#include + +#include "core/partitioning/partitioningctx/PartitioningCtx.h" +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace partitioning { + +PartitioningCtx::PartitioningCtx(torch::jit::Block* b, PartitioningInfo info) + : settings(info), + 
forced_fallback_ops(info.forced_fallback_operators.begin(), info.forced_fallback_operators.end()) { + LOG_DEBUG(settings); + _load_nodes_into_decision_map(b); +} + +void PartitioningCtx::_load_nodes_into_decision_map(torch::jit::Block* b) { + if (!b->owningNode() || b->owningNode()->kind() != torch::jit::prim::Loop) { + original_blocks.push_back(b); + } + for (const auto n : b->nodes()) { + if (n->kind() == torch::jit::prim::Constant) { + continue; + } + node_executor_decision_map[n] = NodeExecutorDecision::kUNKNOWN; + for (const auto sub_b : n->blocks()) { + _load_nodes_into_decision_map(sub_b); + } + } +} + +void PartitioningCtx::setNodeExecutorDecision(torch::jit::Node* n, NodeExecutorDecision decision) { + auto iter = node_executor_decision_map.find(n); + auto prev_decision = NodeExecutorDecision::kUNKNOWN; + if (iter != node_executor_decision_map.end()) { + prev_decision = iter->second; + } + LOG_DEBUG("Setting node " << util::node_info(n) << " " << decision << " (previously was " << prev_decision << ")"); + + node_executor_decision_map[n] = decision; + return; +} + +bool PartitioningCtx::shouldNodeRunInTorch(torch::jit::Node* n) { + auto iter = node_executor_decision_map.find(n); + auto decision = NodeExecutorDecision::kUNKNOWN; + + if (iter != node_executor_decision_map.end()) { + decision = iter->second; + } + if (decision == NodeExecutorDecision::kCONVERT || decision == NodeExecutorDecision::kUNKNOWN) { + return false; + } else { + return true; + } +} + +bool PartitioningCtx::shouldNodeRunInTensorRT(torch::jit::Node* n) { + auto iter = node_executor_decision_map.find(n); + auto decision = NodeExecutorDecision::kUNKNOWN; + if (iter != node_executor_decision_map.end()) { + decision = iter->second; + } + + if (decision == NodeExecutorDecision::kCONVERT) { + return true; + } else { + return false; + } +} + +std::vector<torch::jit::Node*> PartitioningCtx::getNodesRunInTorch() { + std::vector<torch::jit::Node*> nodes_run_in_torch; + for (auto i : node_executor_decision_map) { + if (i.second != NodeExecutorDecision::kCONVERT) { + nodes_run_in_torch.push_back(i.first); + } + } + return nodes_run_in_torch; +} + +std::ostream& operator<<(std::ostream& os, const NodeExecutorDecision& format) { + switch (format) { + case NodeExecutorDecision::kUNSUPPORTED: + return os << "to run in torch due to lack of converter support"; + case NodeExecutorDecision::kOPERATOR_FALLBACK: + return os << "to run in torch because the user explicitly requested that this op kind runs in torch"; + case NodeExecutorDecision::kMODULE_FALLBACK: + return os << "to run in torch because it is a member of a module the user has requested to run in torch"; + case NodeExecutorDecision::kMIN_BLOCK_FALLBACK: + return os << "to run in torch because its owning block is not large enough to exceed the user specified min_block_size"; + case NodeExecutorDecision::kNON_TENSOR: + return os << "to run in torch due to producing or consuming non-tensor values"; + case NodeExecutorDecision::kCONVERT: + return os << "to run in tensorrt"; + case NodeExecutorDecision::kUNKNOWN: + default: + return os << "unknown node executor decision"; + } +} + +std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g) { + os << "Partitioned Graph: ["; + for (auto b : g) { + os << b; + } + os << "]"; + return os; +} + +} // namespace partitioning +} // namespace core +} // namespace torch_tensorrt diff --git a/core/partitioning/partitioningctx/PartitioningCtx.h b/core/partitioning/partitioningctx/PartitioningCtx.h new file mode 100644 index 0000000000..ed8e705be5 --- /dev/null +++ 
b/core/partitioning/partitioningctx/PartitioningCtx.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include +#include + +#include "core/partitioning/partitioninginfo/PartitioningInfo.h" +#include "core/partitioning/segmentedblock/SegmentedBlock.h" + +namespace torch_tensorrt { +namespace core { +namespace partitioning { + +enum NodeExecutorDecision { + /// Node is not supported by TensorRT + kUNSUPPORTED, + /// Node is explicitly forced to fallback to Pytorch due to operator fallback + kOPERATOR_FALLBACK, + /// Node is explicitly forced to fallback to Pytorch due to module fallback + kMODULE_FALLBACK, + /// This node is in a TRT segment which does not satisfy min_block_size + /// and hence is forced to fallback. + kMIN_BLOCK_FALLBACK, + /// This node produces/consumes non-tensor inputs + kNON_TENSOR, + /// This node is going to be converted + kCONVERT, + /// Sentinel + kUNKNOWN, +}; + +std::ostream& operator<<(std::ostream& os, const NodeExecutorDecision& format); + +typedef std::unordered_map NodeExecutorDecisionMap; + +typedef std::vector PartitionedGraph; + +std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g); + +struct UsageInfo { + size_t produce_id; // id of segmented block which contains a raw value of a given torch::jit::Value + std::vector torch_use_id; // ids of segmented blocks which are of type Pytorch + std::vector tensorrt_use_id; // ids of segmented blocks which are of type TensorRT +}; + +struct PartitioningCtx { + // TODO: Make the set a part of settings not stand alone + PartitioningInfo settings; + // records all the original blocks topologically in the module + std::vector original_blocks; + // mapping: node=> execution status + NodeExecutorDecisionMap node_executor_decision_map; + // LUT of the segmented blocks for each blocks in the module + std::unordered_map partitioned_blocks; + std::unordered_set forced_fallback_ops; + + PartitioningCtx(torch::jit::Block* b, PartitioningInfo info); + void setNodeExecutorDecision(torch::jit::Node* n, NodeExecutorDecision decision); + bool shouldNodeRunInTorch(torch::jit::Node* n); + bool shouldNodeRunInTensorRT(torch::jit::Node* n); + std::vector getNodesRunInTorch(); + + private: + void _load_nodes_into_decision_map(torch::jit::Block* b); +}; + +std::ostream& operator<<(std::ostream& os, const PartitioningCtx& s); + +} // namespace partitioning +} // namespace core +} // namespace torch_tensorrt diff --git a/core/partitioning/partitioninginfo/BUILD b/core/partitioning/partitioninginfo/BUILD new file mode 100644 index 0000000000..74e34d134b --- /dev/null +++ b/core/partitioning/partitioninginfo/BUILD @@ -0,0 +1,39 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") +load("@rules_pkg//:pkg.bzl", "pkg_tar") + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "use_pre_cxx11_abi", + values = { + "define": "abi=pre_cxx11_abi", + }, +) + +cc_library( + name = "partitioninginfo", + srcs = [ + "PartitioningInfo.cpp", + ], + hdrs = [ + "PartitioningInfo.h", + ], + deps = [ + "//core/util:prelude", + "//core/ir", + "//core/conversion", + "//core/lowering", + ] + select({ + ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], + "//conditions:default": ["@libtorch//:libtorch"], + }), + alwayslink = True, +) + +pkg_tar( + name = "include", + srcs = [ + "PartitioningInfo.h", + ], + package_dir = "core/partitioning/partitioninginfo", +) diff --git a/core/partitioning/partitioninginfo/CMakeLists.txt b/core/partitioning/partitioninginfo/CMakeLists.txt new file mode 100644 index 
0000000000..86c7388daf --- /dev/null +++ b/core/partitioning/partitioninginfo/CMakeLists.txt @@ -0,0 +1,12 @@ +set(sub_lib_name "partitioninginfo") + +target_sources(${lib_name} + PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/PartitioningInfo.cpp" +) + +set(HEADER_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/PartitioningInfo.h" +) + +# Install headers +install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning/${sub_lib_name}") diff --git a/core/partitioning/PartitionInfo.cpp b/core/partitioning/partitioninginfo/PartitioningInfo.cpp similarity index 82% rename from core/partitioning/PartitionInfo.cpp rename to core/partitioning/partitioninginfo/PartitioningInfo.cpp index 59e29a9bf1..16bdd7b9a7 100644 --- a/core/partitioning/PartitionInfo.cpp +++ b/core/partitioning/partitioninginfo/PartitioningInfo.cpp @@ -2,13 +2,13 @@ #include #include -#include "core/partitioning/PartitionInfo.h" +#include "core/partitioning/partitioninginfo/PartitioningInfo.h" namespace torch_tensorrt { namespace core { namespace partitioning { // clang-format off -std::ostream& operator<<(std::ostream& os, const PartitionInfo& s) { +std::ostream& operator<<(std::ostream& os, const PartitioningInfo& s) { os << "Settings requested for Torch Fallback:" \ << "\n \"enabled\": "; if (s.enabled) { diff --git a/core/partitioning/PartitionInfo.h b/core/partitioning/partitioninginfo/PartitioningInfo.h similarity index 67% rename from core/partitioning/PartitionInfo.h rename to core/partitioning/partitioninginfo/PartitioningInfo.h index dc63597912..8eb052e0fa 100644 --- a/core/partitioning/PartitionInfo.h +++ b/core/partitioning/partitioninginfo/PartitioningInfo.h @@ -4,18 +4,21 @@ #include #include +#include "core/ir/ir.h" + namespace torch_tensorrt { namespace core { namespace partitioning { -struct PartitionInfo { +struct PartitioningInfo { + ir::CollectionInputSpecMap collection_input_spec_map; bool enabled = false; uint64_t min_block_size = 1; std::vector forced_fallback_operators; bool truncate_long_and_double; }; -std::ostream& operator<<(std::ostream& os, const PartitionInfo& s); +std::ostream& operator<<(std::ostream& os, const PartitioningInfo& s); } // namespace partitioning } // namespace core diff --git a/core/partitioning/segmentedblock/BUILD b/core/partitioning/segmentedblock/BUILD new file mode 100644 index 0000000000..8efe1e6b0a --- /dev/null +++ b/core/partitioning/segmentedblock/BUILD @@ -0,0 +1,39 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") +load("@rules_pkg//:pkg.bzl", "pkg_tar") + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "use_pre_cxx11_abi", + values = { + "define": "abi=pre_cxx11_abi", + }, +) + +cc_library( + name = "segmentedblock", + srcs = [ + "SegmentedBlock.cpp", + ], + hdrs = [ + "SegmentedBlock.h", + ], + deps = [ + "//core/util:prelude", + "//core/ir", + "//core/conversion", + "//core/lowering", + ] + select({ + ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], + "//conditions:default": ["@libtorch//:libtorch"], + }), + alwayslink = True, +) + +pkg_tar( + name = "include", + srcs = [ + "SegmentedBlock.h", + ], + package_dir = "core/partitioning/segmentedblock", +) diff --git a/core/partitioning/segmentedblock/CMakeLists.txt b/core/partitioning/segmentedblock/CMakeLists.txt new file mode 100644 index 0000000000..ad6d9ee875 --- /dev/null +++ b/core/partitioning/segmentedblock/CMakeLists.txt @@ -0,0 +1,12 @@ +set(sub_lib_name "segmentedblock") + +target_sources(${lib_name} + PRIVATE 
"${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.cpp" +) + +set(HEADER_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/SegmentedBlock.h" +) + +# Install headers +install(FILES ${HEADER_FILES} DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/torch_tensorrt/core/partitioning/${sub_lib_name}") diff --git a/core/partitioning/SegmentedBlock.cpp b/core/partitioning/segmentedblock/SegmentedBlock.cpp similarity index 100% rename from core/partitioning/SegmentedBlock.cpp rename to core/partitioning/segmentedblock/SegmentedBlock.cpp diff --git a/core/partitioning/SegmentedBlock.h b/core/partitioning/segmentedblock/SegmentedBlock.h similarity index 98% rename from core/partitioning/SegmentedBlock.h rename to core/partitioning/segmentedblock/SegmentedBlock.h index f7d8a0b612..0e04237f63 100644 --- a/core/partitioning/SegmentedBlock.h +++ b/core/partitioning/segmentedblock/SegmentedBlock.h @@ -5,7 +5,6 @@ #include "NvInfer.h" #include "core/ir/ir.h" -#include "core/partitioning/PartitionInfo.h" #include "torch/csrc/jit/ir/ir.h" namespace torch_tensorrt { diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index f940c87751..514681a088 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -1,9 +1,10 @@ -#include "core/partitioning/shape_analysis.h" -#include -#include "core/util/prelude.h" +#include "ATen/ATen.h" #include "torch/csrc/jit/api/module.h" #include "torch/csrc/jit/passes/constant_pooling.h" +#include "core/partitioning/partitioning.h" +#include "core/util/prelude.h" + namespace torch_tensorrt { namespace core { namespace partitioning { @@ -61,7 +62,7 @@ std::unordered_map generateRandomI void getSegmentsOutputByRunning( SegmentedBlock& seg_block, std::unordered_map& ivalues_maps, - const PartitionInfo& partition_info) { + const PartitioningInfo& partitioning_info) { // create a module to run the graph auto g = seg_block.g(); auto copy_g = g->copy(); @@ -151,13 +152,13 @@ void getSegmentsOutputByRunning( // shape inference auto cur_ivalue = ivalues_maps[i]; at::ScalarType t = cur_ivalue.toTensor().scalar_type(); - if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) { + if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) { TORCHTRT_THROW_ERROR( "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled"); - } else if (partition_info.truncate_long_and_double && t == at::kLong) { + } else if (partitioning_info.truncate_long_and_double && t == at::kLong) { cur_ivalue = cur_ivalue.toTensor().to(at::kInt); LOG_WARNING("Truncating graph input type from at::kLong to at::kInt"); - } else if (partition_info.truncate_long_and_double && t == at::kDouble) { + } else if (partitioning_info.truncate_long_and_double && t == at::kDouble) { cur_ivalue = cur_ivalue.toTensor().to(at::kFloat); LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat"); } @@ -180,14 +181,11 @@ void getSegmentsOutputByRunning( seg_block.register_intypes(input_types); } -void runShapeAnalysis( - std::vector& segmented_blocks, - std::unordered_map& example_tensor_map, - const PartitionInfo& partition_info) { +void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) { // register every segment's input shape, and it's running output IValues - for (auto& seg_block : segmented_blocks) { + for (auto& seg_block : ctx->partitioned_blocks[block]) { torch::jit::ConstantPooling(seg_block.g()); - 
getSegmentsOutputByRunning(seg_block, example_tensor_map, partition_info); + getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings); } return; } diff --git a/core/partitioning/shape_analysis.h b/core/partitioning/shape_analysis.h deleted file mode 100644 index 780449d514..0000000000 --- a/core/partitioning/shape_analysis.h +++ /dev/null @@ -1,20 +0,0 @@ -#include "core/ir/ir.h" -#include "core/partitioning/SegmentedBlock.h" -#include "torch/csrc/jit/ir/ir.h" - -namespace torch_tensorrt { -namespace core { -namespace partitioning { - -std::unordered_map generateRandomInputs( - std::unordered_map>& input_ranges, - std::unordered_map>>& input_types); - -void runShapeAnalysis( - std::vector& segmented_blocks, - std::unordered_map& ivalues_maps, - const PartitionInfo& partition_info); - -} // namespace partitioning -} // namespace core -} // namespace torch_tensorrt diff --git a/core/partitioning/stitching.cpp b/core/partitioning/stitching.cpp new file mode 100644 index 0000000000..6ed5a27463 --- /dev/null +++ b/core/partitioning/stitching.cpp @@ -0,0 +1,151 @@ +#include "ATen/ATen.h" +#include "torch/csrc/jit/api/module.h" +#include "torch/csrc/jit/ir/ir_views.h" + +#include "core/partitioning/partitioning.h" +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace partitioning { + +void addSegmentedBlockToGraph( + std::shared_ptr& g, + partitioning::SegmentedBlock& seg, + std::unordered_map& old_to_new_g) { + // old_to_new_g contains: original global graph value => new global graph value, + // mini_to_new_g: mini graph value -> new graph value + std::unordered_map mini_to_new_g; + size_t input_idx = 0; + if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) { + if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) { + auto self = g->insertInput(0, "self_1"); + self->setType(seg.inputs()[0]->type()); + } + mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0]; + } + + for (auto& raw_input : seg.raw_inputs()) { + if (old_to_new_g.count(raw_input)) { + mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input]; + } + } + + for (const auto n : seg.nodes()) { + util::cloneNode(n, g, mini_to_new_g); + } + + // original graph value => new global graph value + for (size_t i = 0; i < seg.raw_outputs().size(); ++i) { + old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]]; + } + size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 
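getSegmentsOutputByRunning now reads its settings from ctx->settings, but the input-type rule it enforces is unchanged. Restated as a standalone helper (names here are illustrative, not part of the diff):

#include <stdexcept>
#include "ATen/ATen.h"

// Same rule as in getSegmentsOutputByRunning above: without truncation enabled,
// int64/double segment inputs are a hard error; with it, they are downcast
// before shape inference runs the segment.
at::Tensor prepare_segment_input(const at::Tensor& t, bool truncate_long_and_double) {
  at::ScalarType st = t.scalar_type();
  if (st != at::kLong && st != at::kDouble) {
    return t;
  }
  if (!truncate_long_and_double) {
    throw std::runtime_error(
        "Unable to process subgraph input type of at::kLong/at::kDouble; "
        "compile with truncate_long_and_double enabled");
  }
  return t.to(st == at::kLong ? at::kInt : at::kFloat);
}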
1 : 0; + for (size_t i = 0; i < seg.raw_inputs().size(); ++i) { + if (!old_to_new_g.count(seg.raw_inputs()[i])) { + old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]]; + } + } + + return; +} + +void addIfBlockToGraph( + std::shared_ptr& new_g, + torch::jit::Node* if_node, + const std::vector& graph_and_mappings, + std::unordered_map& old_to_new_g) { + torch::jit::IfView if_view(if_node); + + // create a new if node in new_g and add corresponding inputs + auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0)); + new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g)); + + // iterate over all blocks and add them to new created prim::If + for (auto graph_and_mapping : graph_and_mappings) { + auto new_if_block = new_if->addBlock(); + auto cur_block_graph = graph_and_mapping.first; + auto cur_block_mapping = graph_and_mapping.second; + std::unordered_map block_graph_to_new_g; + for (auto& i : cur_block_mapping) { + // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then + // it's mini graph's input + if (old_to_new_g.count(i.first)) { + block_graph_to_new_g[i.second] = old_to_new_g[i.first]; + } + } + + auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); }; + new_if_block->cloneFrom(cur_block_graph->block(), env); + if (cur_block_graph->inputs().size() && + cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) { + if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) { + auto self = new_g->insertInput(0, "self_1"); + self->setType(cur_block_graph->inputs()[0]->type()); + } + block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0]; + } + for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) { + new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]); + new_if_block->eraseInput(i); + } + } + for (auto ov : if_view.outputs()) { + auto no = new_if->addOutput(); + old_to_new_g[ov] = no; + no->copyMetadata(ov); + } + return; +} + +GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block) { + auto new_g = std::make_shared(); + + // the mapping from lowering graph => fallback global graph + std::unordered_map old_to_new_g; + for (auto input : block->inputs()) { + util::getOrAddInputForValue(input, new_g, old_to_new_g); + } + + for (auto seg_block : ctx->partitioned_blocks[block]) { + LOG_INFO("Block segment:" << seg_block); + if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) { + addSegmentedBlockToGraph(new_g, seg_block, old_to_new_g); + } else { + if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) { + auto if_node = seg_block.raw_nodes()[0]; + + // convert the 2 blocks in prim::if and get the converted graph with mappings + std::vector graph_and_mappings; + for (auto cur_block : if_node->blocks()) { + graph_and_mappings.push_back(stitch(ctx, cur_block)); + } + addIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g); + + } else { + addSegmentedBlockToGraph(new_g, seg_block, old_to_new_g); + } + } + } + + if (block->outputs().size() > 1) { + std::vector fallback_graph_vector; + for (auto& output : block->outputs()) { + if (old_to_new_g.count(output)) { + fallback_graph_vector.push_back(old_to_new_g[output]); + } + } + torch::jit::ArrayRef fallback_graph_outputs(fallback_graph_vector); + auto return_tuple_node = 
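The env lambda handed to Block::cloneFrom above is the standard TorchScript remapping hook: it is consulted for every value that is referenced but not produced inside the block being cloned. A minimal, self-contained use of the same mechanism (this helper is not part of the diff; torch::jit::Graph::copy() is built the same way):

#include <memory>
#include "torch/csrc/jit/ir/ir.h"

// Clone a whole graph by copying its top-level block. addIfBlockToGraph above
// uses the same cloneFrom + env pattern, but its environment resolves free
// values through the old_to_new_g mapping instead of asserting.
std::shared_ptr<torch::jit::Graph> copy_graph(const std::shared_ptr<torch::jit::Graph>& src) {
  auto dst = std::make_shared<torch::jit::Graph>();
  auto env = [](torch::jit::Value* v) -> torch::jit::Value* {
    TORCH_INTERNAL_ASSERT(false, "top-level blocks should have no free values");
    return v;
  };
  dst->block()->cloneFrom(src->block(), env);
  return dst;
}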
new_g->createTuple(fallback_graph_outputs); + new_g->block()->appendNode(return_tuple_node); + // Set the output as the produced tuple + new_g->registerOutput(return_tuple_node->outputs()[0]); + } else { + if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) { + new_g->registerOutput(old_to_new_g[block->outputs()[0]]); + } + } + return {new_g, old_to_new_g}; +} +} // namespace partitioning +} // namespace core +} // namespace torch_tensorrt diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp index cfbc228396..3d7d9b15d3 100644 --- a/cpp/src/compile_spec.cpp +++ b/cpp/src/compile_spec.cpp @@ -121,10 +121,10 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) { "require_full_compilation is enabled however the list of modules to run in torch is not empty (Found " << external.torch_executed_modules.size() << " modules)"); - internal.partition_info.enabled = !external.require_full_compilation; - internal.partition_info.min_block_size = external.min_block_size; - internal.partition_info.forced_fallback_operators = std::move(external.torch_executed_ops); - internal.partition_info.truncate_long_and_double = external.truncate_long_and_double; + internal.partitioning_info.enabled = !external.require_full_compilation; + internal.partitioning_info.min_block_size = external.min_block_size; + internal.partitioning_info.forced_fallback_operators = std::move(external.torch_executed_ops); + internal.partitioning_info.truncate_long_and_double = external.truncate_long_and_double; internal.lower_info.forced_fallback_modules = std::move(external.torch_executed_modules); switch (external.device.device_type) { diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index 96fef793fd..1721ffd6c9 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -313,10 +313,10 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() { info.convert_info.engine_settings.device.gpu_id = device.gpu_id; info.convert_info.engine_settings.device.dla_core = device.dla_core; info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback; - info.partition_info.enabled = torch_fallback.enabled; - info.partition_info.min_block_size = torch_fallback.min_block_size; - info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators; - info.partition_info.truncate_long_and_double = truncate_long_and_double; + info.partitioning_info.enabled = torch_fallback.enabled; + info.partitioning_info.min_block_size = torch_fallback.min_block_size; + info.partitioning_info.forced_fallback_operators = torch_fallback.forced_fallback_operators; + info.partitioning_info.truncate_long_and_double = truncate_long_and_double; info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules; info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double; diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD index b33685a647..75ae818905 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -75,6 +75,10 @@ lowering_test( name = "test_silu_to_sigmoid_multiplication", ) +lowering_test( + name = "test_unpack_hardsigmoid", +) + lowering_test( name = "test_unpack_hardswish", ) @@ -98,6 +102,7 @@ test_suite( ":test_remove_detach_pass", ":test_remove_dropout_pass", ":test_remove_unnecessary_casts", + ":test_unpack_hardsigmoid", ":test_unpack_hardswish", ":test_unpack_reduce_ops", 
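Every consumer of the old partition_info member is updated mechanically; the C++ frontend above and the tests below now write to partitioning_info instead. For reference, a typical internal CompileSpec setup after this change looks like the following (condensed from the tests; the shape and op name are just examples):

#include "core/compiler.h"
#include "core/ir/ir.h"

torch_tensorrt::core::CompileSpec make_fallback_cfg() {
  std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
  torch_tensorrt::core::CompileSpec cfg(inputs);
  cfg.partitioning_info.enabled = true;
  cfg.partitioning_info.min_block_size = 1;
  cfg.partitioning_info.forced_fallback_operators.push_back("aten::add");
  cfg.partitioning_info.truncate_long_and_double = true;
  cfg.convert_info.engine_settings.truncate_long_and_double = true;
  return cfg;
}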
":test_view_to_reshape_pass", diff --git a/tests/core/lowering/test_module_fallback_passes.cpp b/tests/core/lowering/test_module_fallback_passes.cpp index e6eb098079..5f4ac5f0c2 100644 --- a/tests/core/lowering/test_module_fallback_passes.cpp +++ b/tests/core/lowering/test_module_fallback_passes.cpp @@ -100,7 +100,7 @@ TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) { std::vector input_ranges{torch_tensorrt::core::ir::Input({1, 1, 16, 16})}; torch_tensorrt::core::CompileSpec cfg(input_ranges); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; cfg.lower_info.forced_fallback_modules.push_back("ModuleFallbackSub"); auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); diff --git a/tests/core/lowering/test_unpack_hardsigmoid.cpp b/tests/core/lowering/test_unpack_hardsigmoid.cpp new file mode 100644 index 0000000000..f8206511be --- /dev/null +++ b/tests/core/lowering/test_unpack_hardsigmoid.cpp @@ -0,0 +1,87 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "torch/csrc/jit/ir/subgraph_matcher.h" + +TEST(LoweringPasses, UnpackHardSigmoid) { + std::string source_graph = R"IR( + graph(%input): + %result = aten::hardsigmoid(%input) + return (%result))IR"; + + std::string target_graph = R"IR( + graph(%x.1): + %22 : float = prim::Constant[value=0.5]() + %3 : int = prim::Constant[value=6]() + %5 : int = prim::Constant[value=1]() + %10 : int = prim::Constant[value=0]() + %4 : Tensor = aten::div(%x.1, %3) + %9 : Tensor = aten::add(%4, %22, %5) + %21 : Tensor = aten::clamp(%9, %10, %5) + return (%21))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + + auto in = at::rand({10, 100}, {at::kCUDA}); + auto sg_params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {}); + auto sg_results = torch_tensorrt::tests::util::RunGraph(sg, sg_params, {in}); + + torch_tensorrt::core::lowering::passes::UnpackHardSigmoid(sg); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + + ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); + + in = at::clone(in); + auto tg_params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {}); + auto tg_results = torch_tensorrt::tests::util::RunGraph(tg, tg_params, {in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6)); +} + +TEST(LoweringPasses, UnpackHardSigmoidInPlace) { + std::string source_graph = R"IR( + graph(%input): + %result = aten::hardsigmoid_(%input) + return (%result))IR"; + + std::string target_graph = R"IR( + graph(%x.1): + %22 : float = prim::Constant[value=0.5]() + %3 : int = prim::Constant[value=6]() + %5 : int = prim::Constant[value=1]() + %10 : int = prim::Constant[value=0]() + %4 : Tensor = aten::div(%x.1, %3) + %9 : Tensor = aten::add(%4, %22, %5) + %21 : Tensor = aten::clamp(%9, %10, %5) + return (%21))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + + auto in = at::rand({10, 100}, {at::kCUDA}); + auto sg_params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {}); + auto sg_results = torch_tensorrt::tests::util::RunGraph(sg, 
sg_params, {in}); + + torch_tensorrt::core::lowering::passes::UnpackHardSigmoid(sg); + + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + + ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); + + in = at::clone(in); + auto tg_params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {}); + auto tg_results = torch_tensorrt::tests::util::RunGraph(tg, tg_params, {in}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6)); +} diff --git a/tests/core/lowering/test_view_to_reshape_pass.cpp b/tests/core/lowering/test_view_to_reshape_pass.cpp index d1f787bc10..a6254bccde 100644 --- a/tests/core/lowering/test_view_to_reshape_pass.cpp +++ b/tests/core/lowering/test_view_to_reshape_pass.cpp @@ -66,8 +66,8 @@ TEST(LoweringPasses, ViewToReshapeResultsCorrectly) { std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({2, 3, 4, 5})); torch_tensorrt::core::CompileSpec cfg(inputs); - cfg.partition_info.enabled = true; - cfg.partition_info.forced_fallback_operators.push_back("aten::permute"); + cfg.partitioning_info.enabled = true; + cfg.partitioning_info.forced_fallback_operators.push_back("aten::permute"); torch::jit::script::Module mod(c10::QualifiedName("module")); diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp index 424fac86e0..ba336db663 100644 --- a/tests/core/partitioning/test_conditionals.cpp +++ b/tests/core/partitioning/test_conditionals.cpp @@ -34,7 +34,7 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) { std::vector inputs{torch_tensorrt::core::ir::Input({3, 3, 16, 16})}; auto g = mod.get_method("forward").graph(); torch_tensorrt::core::CompileSpec cfg(inputs); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg); auto new_g = new_mod.get_method("forward").graph(); @@ -65,8 +65,8 @@ TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) { torch_tensorrt::core::ir::Input({4, 4}), torch_tensorrt::core::ir::Input({4, 4})}; auto g = mod.get_method("forward").graph(); torch_tensorrt::core::CompileSpec cfg(inputs); - cfg.partition_info.enabled = true; - cfg.partition_info.forced_fallback_operators.push_back("prim::ListConstruct"); + cfg.partitioning_info.enabled = true; + cfg.partitioning_info.forced_fallback_operators.push_back("prim::ListConstruct"); auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg); diff --git a/tests/core/partitioning/test_fallback_graph_output.cpp b/tests/core/partitioning/test_fallback_graph_output.cpp index 3da717074a..f6ce657ae3 100644 --- a/tests/core/partitioning/test_fallback_graph_output.cpp +++ b/tests/core/partitioning/test_fallback_graph_output.cpp @@ -28,8 +28,8 @@ TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) { std::vector input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})}; torch_tensorrt::core::CompileSpec cfg(input_ranges); - cfg.partition_info.enabled = true; - cfg.partition_info.forced_fallback_operators.push_back("aten::add"); + cfg.partitioning_info.enabled = true; + cfg.partitioning_info.forced_fallback_operators.push_back("aten::add"); auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg); @@ -58,8 +58,8 @@ TEST(Partitioning, ComputeMobileNetFallbackGraphCorrectly) { std::vector 
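The new UnpackHardSigmoid tests above pin the pass to the identity hardsigmoid(x) = clamp(x/6 + 1/2, 0, 1), which is what the div/add/clamp target graphs encode. A quick ATen-level check of that identity, independent of the lowering pass (not part of the diff):

#include "ATen/ATen.h"

// hardsigmoid(x) == clamp(x / 6 + 0.5, 0, 1)
bool hardsigmoid_identity_holds() {
  auto x = at::randn({10, 100});
  auto reference = at::hardsigmoid(x);
  auto unpacked = at::clamp(x / 6.0 + 0.5, 0.0, 1.0);
  return at::allclose(reference, unpacked, /*rtol=*/1e-5, /*atol=*/2e-6);
}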
input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})}; auto g = mod.get_method("forward").graph(); torch_tensorrt::core::CompileSpec cfg(input_ranges); - cfg.partition_info.enabled = true; - cfg.partition_info.forced_fallback_operators.push_back("aten::hardtanh"); + cfg.partitioning_info.enabled = true; + cfg.partitioning_info.forced_fallback_operators.push_back("aten::hardtanh"); auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg); diff --git a/tests/core/partitioning/test_loading_model.cpp b/tests/core/partitioning/test_loading_model.cpp index 057aaff2d8..b42368fe3e 100644 --- a/tests/core/partitioning/test_loading_model.cpp +++ b/tests/core/partitioning/test_loading_model.cpp @@ -28,7 +28,7 @@ TEST(Partitioning, ComputeResNet50FallbackGraphCorrectly) { std::vector input_ranges{torch_tensorrt::core::ir::Input({1, 3, 224, 224})}; torch_tensorrt::core::CompileSpec cfg(input_ranges); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg); diff --git a/tests/core/partitioning/test_loop_fallback.cpp b/tests/core/partitioning/test_loop_fallback.cpp index 83556b5512..5f6bc2ae4d 100644 --- a/tests/core/partitioning/test_loop_fallback.cpp +++ b/tests/core/partitioning/test_loop_fallback.cpp @@ -25,7 +25,7 @@ TEST(Partitioning, CheckLoopFallbackEvalCompilesCorrectly) { std::vector input_ranges{torch_tensorrt::core::ir::Input({1, 10})}; torch_tensorrt::core::CompileSpec cfg(input_ranges); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg); @@ -53,7 +53,7 @@ TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) { std::vector input_ranges{torch_tensorrt::core::ir::Input({1, 10})}; torch_tensorrt::core::CompileSpec cfg(input_ranges); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; auto jit_results = mod.forward(jit_inputs_ivalues).toTensor(); auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg); diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index 30656a3d9e..950859e524 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -60,10 +60,10 @@ TEST(Partitioning, ResolveNonTensorInputsForIFBlockCorrectly) { inputs.push_back(torch_tensorrt::core::ir::Input({3, 4})); inputs.push_back(torch_tensorrt::core::ir::Input({3, 4})); torch_tensorrt::core::CompileSpec cfg(inputs); - cfg.partition_info.enabled = true; - cfg.partition_info.forced_fallback_operators.push_back("aten::sub"); + cfg.partitioning_info.enabled = true; + cfg.partitioning_info.forced_fallback_operators.push_back("aten::sub"); cfg.convert_info.engine_settings.truncate_long_and_double = true; - cfg.partition_info.truncate_long_and_double = true; + cfg.partitioning_info.truncate_long_and_double = true; torch::jit::script::Module mod(c10::QualifiedName("module")); @@ -109,8 +109,8 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; + torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info; + 
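These test updates also highlight that the two fallback knobs stay separate after the rename: operator-level fallback is configured on partitioning_info, while module-level fallback still flows through lower_info (as in the compile_spec.cpp hunk earlier). A combined fragment, assuming input_ranges is defined as in the tests:

torch_tensorrt::core::CompileSpec cfg(input_ranges);
cfg.partitioning_info.enabled = true;
cfg.partitioning_info.forced_fallback_operators.push_back("aten::hardtanh"); // single op runs in Torch
cfg.lower_info.forced_fallback_modules.push_back("ModuleFallbackSub");       // entire submodule runs in Torch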
partitioning_info.enabled = true; std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({1, 3, 16, 16})); inputs.push_back(torch_tensorrt::core::ir::Input({16, 3, 3, 3})); @@ -123,9 +123,10 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) { input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); - std::unordered_map fallback_nodes; + torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info); + torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map); std::vector segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); + ctx.partitioned_blocks.begin()->second; int torch_block_cnt = 0, trt_block_cnt = 0; for (const auto& segmented_block : segmented_blocks) { @@ -168,8 +169,8 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; + torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info; + partitioning_info.enabled = true; std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({1, 3, 16, 16})); inputs.push_back(torch_tensorrt::core::ir::Input({16, 6, 3, 3})); @@ -182,9 +183,11 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) { input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); - std::unordered_map fallback_nodes; + torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info); + + torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map); std::vector segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); + ctx.partitioned_blocks.begin()->second; int torch_block_cnt = 0, trt_block_cnt = 0; for (const auto& segmented_block : segmented_blocks) { @@ -244,7 +247,7 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) { std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({1, 3, 16, 16})); torch_tensorrt::core::CompileSpec cfg(inputs); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; torch::jit::script::Module mod(c10::QualifiedName("module")); auto self = g->insertInput(0, "self_1"); @@ -361,8 +364,8 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { g->registerOutput(get_ins_node->output()); g->registerOutput(get_outs_node->output()); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; + torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info; + partitioning_info.enabled = true; std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({4, 4})); inputs.push_back(torch_tensorrt::core::ir::Input({4, 4})); @@ -374,9 +377,9 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); - std::unordered_map fallback_nodes; - auto segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); + torch_tensorrt::core::partitioning::PartitioningCtx 
ctx(g->block(), partitioning_info); + torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map); + auto segmented_blocks = ctx.partitioned_blocks.begin()->second; int torch_block_cnt = 0, trt_block_cnt = 0; for (const auto& segmented_block : segmented_blocks) { diff --git a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp index bf8a36d081..8d47af553e 100644 --- a/tests/core/partitioning/test_segmentation.cpp +++ b/tests/core/partitioning/test_segmentation.cpp @@ -6,9 +6,14 @@ #include "torch/script.h" #include "torch_tensorrt/torch_tensorrt.h" +namespace torch_tensorrt { +namespace core { +namespace partitioning { +namespace tests { + bool checkSegmentedBlockNumber( - torch_tensorrt::core::partitioning::PartitionedGraph& segmented_blocks, - torch_tensorrt::core::partitioning::SegmentedBlock::SegmentedBlockTarget target, + PartitionedGraph& segmented_blocks, + SegmentedBlock::SegmentedBlockTarget target, int target_count) { int64_t cnt = 0; for (auto& seg_block : segmented_blocks) { @@ -27,7 +32,7 @@ bool checkSegmentedBlockNumber( } bool checkSegmentedBlockNodesMapping( - std::vector& segmented_blocks, + std::vector& segmented_blocks, std::shared_ptr g, std::vector> nodes_index) { std::vector graph_nodes; @@ -71,17 +76,15 @@ TEST(Partitioning, SegmentSequentialModelCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; - std::unordered_map fallback_nodes; - std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 2)); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3}, {4}})); + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3}, {4}})); } TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) { @@ -106,18 +109,16 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; - partition_info.min_block_size = 3; - std::unordered_map fallback_nodes; - std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1)); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}})); + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + 
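The pattern above is the new public surface of the partitioner: instead of passing a fallback_nodes map into Partition(), callers build a PartitioningCtx, hand it to partition(), and read the per-block results out of ctx.partitioned_blocks. Condensed from the updated tests (g, partitioning_info, and input_ivalues_map as defined there; the scope of partition() is inferred from the shape-analysis and non-tensor-input tests):

using namespace torch_tensorrt::core;

partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
partitioning::partition(&ctx, input_ivalues_map); // segmentation, input resolution, shape analysis

int torch_block_cnt = 0, trt_block_cnt = 0;
for (const auto& seg_block : ctx.partitioned_blocks[g->block()]) {
  if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
    ++trt_block_cnt;
  } else {
    ++torch_block_cnt;
  }
}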
partitioning_info.min_block_size = 3; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4}})); } TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) { @@ -146,18 +147,16 @@ TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; - partition_info.min_block_size = 3; - std::unordered_map fallback_nodes; - std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1)); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2, 3}, {4, 5, 6, 7}})); + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.min_block_size = 3; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2, 3}, {4, 5, 6, 7}})); } TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) { @@ -182,18 +181,16 @@ TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; - partition_info.forced_fallback_operators.push_back("aten::relu"); - std::unordered_map fallback_nodes; - std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 3)); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 2)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0}, {1}, {2}, {3}, {4}})); + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.forced_fallback_operators.push_back("aten::relu"); + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 3)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0}, {1}, {2}, {3}, {4}})); } TEST(Partitioning, SegmentBranchModelCorrectly) { @@ -219,17 +216,15 @@ TEST(Partitioning, SegmentBranchModelCorrectly) { auto g = 
std::make_shared(); torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; - std::unordered_map fallback_nodes; - std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 2)); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3, 4, 5, 6}})); + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 2)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2}, {3, 4, 5, 6}})); } TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) { @@ -255,18 +250,16 @@ TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; - partition_info.min_block_size = 3; - std::unordered_map fallback_nodes; - std::vector segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1)); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4, 5, 6}})); + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.min_block_size = 3; + PartitioningCtx ctx(g->block(), partitioning_info); + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 1)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 1)); + ASSERT_TRUE(checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1, 2}, {3, 4, 5, 6}})); } TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) { @@ -296,16 +289,21 @@ TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); + LOG_GRAPH(*g); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; - partition_info.forced_fallback_operators.push_back("aten::relu"); - std::unordered_map fallback_nodes; - torch_tensorrt::core::partitioning::PartitionedGraph segmented_blocks = - torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes); - ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 3)); + PartitioningInfo partitioning_info; + partitioning_info.enabled = true; + partitioning_info.forced_fallback_operators.push_back("aten::relu"); + PartitioningCtx 
ctx(g->block(), partitioning_info); + + segmentGraph(&ctx, g->block()); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTensorRT, 3)); + ASSERT_TRUE(checkSegmentedBlockNumber(ctx.partitioned_blocks.begin()->second, SegmentedBlock::kTorch, 2)); ASSERT_TRUE( - checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 2)); - ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1}, {2}, {3}, {4}, {5, 6}})); + checkSegmentedBlockNodesMapping(ctx.partitioned_blocks.begin()->second, g, {{0, 1}, {2}, {3}, {4}, {5, 6}})); } + +} // namespace tests +} // namespace partitioning +} // namespace core +} // namespace torch_tensorrt diff --git a/tests/core/partitioning/test_shape_analysis.cpp b/tests/core/partitioning/test_shape_analysis.cpp index 98b375f121..87c42c0e47 100644 --- a/tests/core/partitioning/test_shape_analysis.cpp +++ b/tests/core/partitioning/test_shape_analysis.cpp @@ -48,8 +48,8 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; + torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info; + partitioning_info.enabled = true; std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16})); inputs.push_back(torch_tensorrt::core::ir::Input({32, 3, 3, 3})); @@ -66,9 +66,10 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) { input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); - std::unordered_map fallback_nodes; - std::vector segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); + + torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info); + torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map); + auto segmented_blocks = ctx.partitioned_blocks.begin()->second; ASSERT_TRUE(checkSegmentedBlockInputShape( segmented_blocks, @@ -101,8 +102,8 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) { auto g = std::make_shared(); torch::jit::parseIR(graph, g.get()); - torch_tensorrt::core::partitioning::PartitionInfo partition_info; - partition_info.enabled = true; + torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info; + partitioning_info.enabled = true; std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16})); inputs.push_back(torch_tensorrt::core::ir::Input({32, 3, 3, 3})); @@ -117,9 +118,10 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) { input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); - std::unordered_map fallback_nodes; - std::vector segmented_blocks = - torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes); + + torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info); + torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map); + auto segmented_blocks = ctx.partitioned_blocks.begin()->second; ASSERT_TRUE(checkSegmentedBlockInputShape( segmented_blocks, diff --git a/tests/core/partitioning/test_stitched_graph.cpp 
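Note the split the segmentation tests rely on: segmentGraph(&ctx, block) performs only the segmentation step and needs no example inputs, while partition(&ctx, ivalues) runs the full pipeline and therefore takes the randomly generated IValues. Side by side, as a sketch (g, partitioning_info, and input_ivalues_map as in the tests):

// Segmentation only -- no example inputs required.
partitioning::PartitioningCtx seg_ctx(g->block(), partitioning_info);
partitioning::segmentGraph(&seg_ctx, g->block());
auto segments = seg_ctx.partitioned_blocks.begin()->second; // PartitionedGraph for this block

// Full partitioning -- needs concrete IValues so segment shapes can be inferred.
partitioning::PartitioningCtx full_ctx(g->block(), partitioning_info);
partitioning::partition(&full_ctx, input_ivalues_map);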
b/tests/core/partitioning/test_stitched_graph.cpp index 61c5b58552..4332668506 100644 --- a/tests/core/partitioning/test_stitched_graph.cpp +++ b/tests/core/partitioning/test_stitched_graph.cpp @@ -75,7 +75,7 @@ TEST(Partitioning, StitchSequentialModelSegmentedBlockCorrectly) { std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16})); torch_tensorrt::core::CompileSpec cfg(inputs); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg); auto fallback_g = new_mod.get_method("forward").graph(); ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g)); @@ -133,7 +133,7 @@ TEST(Partitioning, StitchBranchModelSegmentedBlockCorrectly) { std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16})); torch_tensorrt::core::CompileSpec cfg(inputs); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg); auto fallback_g = new_mod.get_method("forward").graph(); ASSERT_TRUE(checkAllInputsExistInStitchedGraph(fallback_g)); diff --git a/tests/core/partitioning/test_tensorrt_conversion.cpp b/tests/core/partitioning/test_tensorrt_conversion.cpp index 8b42f95e24..41431c76db 100644 --- a/tests/core/partitioning/test_tensorrt_conversion.cpp +++ b/tests/core/partitioning/test_tensorrt_conversion.cpp @@ -57,7 +57,7 @@ TEST(Partitioning, ConvertSequentialModelSegmentedBlockCorrectly) { std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16})); torch_tensorrt::core::CompileSpec cfg(inputs); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; torch::jit::script::Module mod(c10::QualifiedName("module")); auto self = g->insertInput(0, "self_1"); @@ -116,7 +116,7 @@ TEST(Partitioning, ConvertBranchModelSegmentedBlockCorrectly) { std::vector inputs; inputs.push_back(torch_tensorrt::core::ir::Input({3, 3, 16, 16})); torch_tensorrt::core::CompileSpec cfg(inputs); - cfg.partition_info.enabled = true; + cfg.partitioning_info.enabled = true; torch::jit::script::Module mod(c10::QualifiedName("module")); auto self = g->insertInput(0, "self_1");
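Several of the partitioning tests touched here (test_resolve_nontensor_inputs, test_shape_analysis) build the same two lookup tables before calling generateRandomInputs; the template arguments were lost in the hunks above, so the sketch below reconstructs the plumbing under the assumption that the maps are keyed by graph-input Value* and hold ir::Input specs and optional scalar types, matching how the tests populate them (inputs and g as defined there):

std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;

for (size_t i = 0; i < g->inputs().size(); ++i) {
  inputs_map.insert({g->inputs()[i], {inputs[i]}});     // one ir::Input spec per graph input
  input_types.insert({g->inputs()[i], {{at::kFloat}}}); // expected scalar type per input
}
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);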