diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index eb8c86de50..2e384b25ab 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -453,6 +453,7 @@ void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
     registerSegmentsOutputs(ctx, block);
 
     // run shape analysis on each segmented block
+    LOG_DEBUG("Running shape analysis for segmented graphs");
    runShapeAnalysis(ctx, block, example_tensor_map);
  }
}
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index b49a0efc72..2220b59784 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -1,3 +1,4 @@
+#include <queue>
 #include "ATen/ATen.h"
 #include "torch/csrc/jit/api/module.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
@@ -57,6 +58,61 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
   return ivalue_map;
 }
 
+torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
+  std::queue<torch::jit::Value*> q;
+  q.push(val);
+  std::unordered_set<torch::jit::Node*> visited;
+  while (!q.empty()) {
+    auto cur_val = q.front();
+    q.pop();
+    auto node = cur_val->node();
+    if ((node->kind().toQualString() == std::string("aten::to")) &&
+        ((node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType) ||
+         (node->inputs()[2]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType))) {
+      return node;
+    }
+    if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
+      visited.insert(node);
+      for (auto input : node->inputs()) {
+        q.push(input);
+      }
+    }
+  }
+  return nullptr;
+}
+
+torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input) {
+  auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
+  auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
+  torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
+  auto g = seg_block.g();
+  // if we can find an upstream aten::to node, we use its parameters to create the new cast node
+  if (cast_node) {
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
+    value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
+    if (!is_input) {
+      // if this value is an output, we need to cast it to int32
+      auto const_val = g->insertConstant(3);
+      if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
+        value_map.insert({cast_node->inputs()[2], const_val});
+      } else {
+        value_map.insert({cast_node->inputs()[1], const_val});
+      }
+    }
+    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, g, value_map); };
+    cast_node = g->createClone(cast_node, env);
+    // auto cast_node = g->prependNode(g->createClone(cast_node, env));
+  } else {
+    // if there is no explicit aten::to cast operation, we need to create a node
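+    // the synthesized node follows the aten::to.dtype schema,
+    //   aten::to(Tensor self, ScalarType dtype, bool non_blocking, bool copy, MemoryFormat? memory_format),
+    // with dtype constant 4 (at::kLong) for inputs and 3 (at::kInt) for outputs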
+    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_zero = g->insertConstant(0);
+    const_zero->setType(torch::jit::BoolType::get());
+    auto none_val = g->insertNode(g->createNone())->output();
+    cast_node = g->create(torch::jit::aten::to, {cast_subgraph_value, const_type, const_zero, const_zero, none_val});
+  }
+  return cast_node;
+}
+
 void getSegmentsOutputByRunning(
     SegmentedBlock& seg_block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
@@ -142,16 +198,45 @@ void getSegmentsOutputByRunning(
     ivalues_maps[output] = jit_results[idx++];
   }
 
+  // auto int64 <=> int32 conversion
+  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
+    // First, check if there is an Int64 input
+    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong) {
+          // we add a cast operation to cast the type back to Int64
+          auto cast_node = createCastNode(seg_block, i, true);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        }
+      }
+    }
+    for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong) {
+          auto cast_node = createCastNode(seg_block, i, false);
+          seg_block.g()->appendNode(cast_node);
+          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
+        }
+      }
+    }
+  }
+
   // set input shape for each segmented block so we will use it in conversion process
   std::vector<ir::Input> input_shapes;
   std::vector<at::ScalarType> input_types;
-  for (auto& i : seg_block.raw_inputs()) {
-    if (ivalues_maps[i].isTensor()) {
+  for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
       // set the input_shape and data_type
       // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
       // shape inference
-      auto cur_ivalue = ivalues_maps[i];
+      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+
       if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
         TORCHTRT_THROW_ERROR(
             "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
diff --git a/tests/core/partitioning/BUILD b/tests/core/partitioning/BUILD
index 83722b4271..5aba817bd6 100644
--- a/tests/core/partitioning/BUILD
+++ b/tests/core/partitioning/BUILD
@@ -40,6 +40,10 @@ partitioning_test(
     name = "test_resolve_nontensor_inputs",
 )
 
+partitioning_test(
+    name = "test_type_auto_conversion",
+)
+
 cc_test(
     name = "test_loading_model",
     srcs = ["test_loading_model.cpp"],
@@ -112,5 +116,6 @@ test_suite(
         ":test_shape_analysis",
         ":test_stitched_graph",
         ":test_tensorrt_conversion",
+        ":test_type_auto_conversion",
     ],
 )
diff --git a/tests/core/partitioning/test_type_auto_conversion.cpp b/tests/core/partitioning/test_type_auto_conversion.cpp
new file mode 100644
index 0000000000..d7b7e2391f
--- /dev/null
+++ b/tests/core/partitioning/test_type_auto_conversion.cpp
@@ -0,0 +1,106 @@
+#include <string>
+#include "core/partitioning/partitioning.h"
+#include "core/util/trt_util.h"
+#include "gtest/gtest.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/script.h"
+
+bool checkInsertedCastNodeNumber(torch_tensorrt::core::partitioning::SegmentedBlock& seg_block, int target_count) {
+  int64_t cnt = 0;
+  for (auto node : seg_block.nodes()) {
+    if (node->kind().toQualString() == std::string("aten::to")) {
+      cnt++;
+    }
+  }
+  std::cout << "Found " << cnt << " inserted aten::to nodes (looking for " << target_count << " aten::to nodes)"
+            << std::endl;
+
+  return target_count == cnt;
+}
+
+TEST(Partitioning, ExplicitNodeAutoConversionCorrectly) {
+  const auto graph = R"IR(
+        graph(%0 : Tensor,
+              %1 : Tensor):
+          %2 : int = prim::Constant[value=4]()
+          %3 : bool = prim::Constant[value=0]()
+          %4 : NoneType = prim::Constant()
+          %5 : int = prim::Constant[value=1]()
+          %7 : Tensor = aten::to(%1, %2, %3, %3, %4)
+          %8 : Tensor = aten::mul(%0, %0)
+          %9 : Tensor = aten::scatter(%8, %5, %7, %5)
+          %10 : Tensor = aten::scatter(%7, %5, %7, %5)
+          %12 : Tensor = aten::add(%10, %10, %5)
+          return (%9, %12))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.forced_fallback_operators = {"aten::scatter"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+  inputs_map.insert({g->inputs()[1], {inputs[1]}});
+  input_types.insert({g->inputs()[1], {{at::kInt}}});
+
+  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
+}
+
+TEST(Partitioning, ImplicitAutoConversionCorrectly) {
+  const auto graph = R"IR(
+        graph(%0 : Tensor):
+          %2 : int = prim::Constant[value=0]()
+          %4 : int = aten::size(%0, %2)
+          %6 : Tensor = prim::NumToTensor(%4)
+          %3 : int = prim::Constant[value=5]()
+          %7 : int[] = prim::ListConstruct(%3, %3)
+          %8 : bool = prim::Constant[value=0]()
+          %9 : Tensor = aten::expand(%6, %7, %8)
+
+          %10 : Tensor = aten::mul(%9, %9)
+          return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.forced_fallback_operators = {"aten::expand"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+
+  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
+}