diff --git a/core/partitioning/partitioninginfo/PartitioningInfo.h b/core/partitioning/partitioninginfo/PartitioningInfo.h
index ed7d2033c6..62ad46034f 100644
--- a/core/partitioning/partitioninginfo/PartitioningInfo.h
+++ b/core/partitioning/partitioninginfo/PartitioningInfo.h
@@ -17,6 +17,7 @@ struct PartitioningInfo {
   std::vector<std::string> forced_fallback_operators;
   bool truncate_long_and_double;
   ir::Device target_device;
+  bool cast_int8_inputs = false;
 
   std::string getGPUDeviceString() const {
     return "cuda:" + std::to_string(target_device.gpu_id);
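Note: cast_int8_inputs defaults to false, so partitioning behavior is unchanged unless a caller opts in. A minimal usage sketch, assuming the core partitioning API and mirroring the test added at the end of this patch:

    // Opt in to casting Byte/Int8 tensors that cross Torch/TensorRT block boundaries.
    torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
    partitioning_info.enabled = true;
    partitioning_info.truncate_long_and_double = true;
    partitioning_info.cast_int8_inputs = true;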
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index 4220764dd6..6a648f0063 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -99,18 +99,24 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
   return nullptr;
 }
 
-torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
+torch::jit::Node* createCastNode(
+    SegmentedBlock& seg_block,
+    size_t index,
+    bool is_input,
+    at::ScalarType dtype,
+    std::string device,
+    bool force_create_node = false) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
   auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
   torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
   auto g = seg_block.g();
   // if we can find upstream aten::to node, we use it's parameters for creating new cast node
-  if (cast_node) {
+  if (cast_node && !force_create_node) {
     std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
     value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
     if (!is_input) {
       // if this value is output, we need to cast it to int32
-      auto const_val = g->insertConstant(3);
+      auto const_val = g->insertConstant(dtype);
       if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
         value_map.insert({cast_node->inputs()[2], const_val});
       } else {
@@ -122,7 +128,7 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
     // auto cast_node = g->prependNode(g->createClone(cast_node, env));
   } else {
     // if there is no explicit cast aten::to operation, we need to create a node
-    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_type = g->insertConstant(dtype);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
     auto cuda = g->insertConstant(device);
@@ -222,27 +228,56 @@ void getSegmentsOutputByRunning(
   auto target_device = partitioning_info.getGPUDeviceString();
 
-  // auto int64 <=> int32 conversion
-  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
+  // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
+  if (seg_block.target() == SegmentedBlock::kTorch) {
     // First, check if there is Int64 input
     for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
       if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
         auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor input type during shape analysis, "
+              << "inserting aten::to cast to Long to ensure this Torch block receives "
+              << "a Long-type tensor input.");
           // we add a cast operation to cast the type to Int64
-          auto cast_node = createCastNode(seg_block, i, true, target_device);
+          auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor input type during shape analysis, "
+              << "inserting aten::to cast to Byte to ensure this Torch block receives "
+              << "a Byte-type tensor input.");
+          // If the input has type Byte, ensure it is casted to the correct type
+          auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
           seg_block.g()->prependNode(cast_node);
           seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
         }
       }
     }
+
     for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
       if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
-          auto cast_node = createCastNode(seg_block, i, false, target_device);
+
+        // If the output has type Long and truncation was requested, insert truncate
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
+          seg_block.g()->appendNode(cast_node);
+          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          // If the output has type Byte and casting was requested, insert Integer cast
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }
@@ -254,11 +289,13 @@ void getSegmentsOutputByRunning(
   std::vector<std::vector<int64_t>> input_shapes;
   std::vector<at::ScalarType> input_types;
   for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+    auto current_input = seg_block.raw_inputs()[i];
+
+    if (ivalues_maps[current_input].isTensor()) {
       // set the input_shape and data_type
       // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
       // shape inference
-      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+      auto cur_ivalue = ivalues_maps[current_input];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
 
       if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
@@ -271,10 +308,16 @@ void getSegmentsOutputByRunning(
         cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
         LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
       }
+
       c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
       if (dtype == c10::nullopt) {
         TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
+      } else if (dtype && dtype.value() == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs) {
+        // Special case to ensure input IValues to TensorRT engine are not Int8 type if the
+        // model itself is not quantized
+        cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
       }
+
       if (cur_ivalue.toTensor().sizes().size() == 0) {
         // handle Scalar types, which has sizes of []
         input_shapes.push_back(util::toVec(util::toDims(c10::List<int64_t>({1}))));
@@ -297,6 +340,7 @@ void runShapeAnalysis(
     const ir::ShapeMode& shape_mode) {
   // register every segment's input shape, and it's running output IValues
   for (auto& seg_block : ctx->partitioned_blocks[block]) {
+    LOG_GRAPH("Running shape analysis on block " << seg_block);
     torch::jit::ConstantPooling(seg_block.g());
     getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
   }
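Note: createCastNode now takes the target dtype explicitly instead of hard-coding ScalarType constants (4 = Long for inputs, 3 = Int for outputs), and the new force_create_node flag skips reuse of an upstream aten::to node so that Byte casts are always materialized. The new call sites above reduce to roughly the following, with arguments as in the hunks:

    // Long values entering a Torch block keep their type; Byte values always get a fresh cast node.
    auto long_cast = createCastNode(seg_block, i, /*is_input=*/true, at::kLong, target_device);
    auto byte_cast = createCastNode(seg_block, i, /*is_input=*/true, at::kByte, target_device, /*force_create_node=*/true);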
diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp
index d320992a70..835faaed68 100644
--- a/core/util/trt_util.cpp
+++ b/core/util/trt_util.cpp
@@ -252,6 +252,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
       {at::kHalf, nvinfer1::DataType::kHALF},
       {at::kInt, nvinfer1::DataType::kINT32},
       {at::kChar, nvinfer1::DataType::kINT8},
+      {at::kByte, nvinfer1::DataType::kINT8},
       {at::kBool, nvinfer1::DataType::kBOOL}};
   return at_trt_type_map;
 }
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index 24aba31515..3de2daa14a 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -167,8 +167,11 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.dla_local_dram_size = external.dla_local_dram_size;
   internal.convert_info.engine_settings.dla_global_dram_size = external.dla_global_dram_size;
 
+  internal.partitioning_info.cast_int8_inputs = true;
+
   if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
       internal.convert_info.engine_settings.enabled_precisions.end()) {
+    internal.partitioning_info.cast_int8_inputs = false;
     if (external.ptq_calibrator) {
       internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
     } else {
diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
index 489da576e2..9822f47f3b 100644
--- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp
+++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -300,11 +300,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
     info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }
 
+  info.partitioning_info.cast_int8_inputs = true;
+
   if (ptq_calibrator) {
     info.convert_info.engine_settings.calibrator = ptq_calibrator;
+    info.partitioning_info.cast_int8_inputs = false;
   } else {
     if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
         info.convert_info.engine_settings.enabled_precisions.end()) {
+      info.partitioning_info.cast_int8_inputs = false;
       info.lower_info.unfreeze_module = true;
       info.lower_info.disable_cse = true;
     }
@@ -313,10 +317,23 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.convert_info.engine_settings.disable_tf32 = disable_tf32;
   info.convert_info.engine_settings.refit = refit;
   info.convert_info.engine_settings.debug = debug;
+
+  // Specify + replicate device settings for phases requiring it
   info.convert_info.engine_settings.device.device_type = toTRTDeviceType(device.device_type);
   info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
   info.convert_info.engine_settings.device.dla_core = device.dla_core;
   info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
+
+  info.lower_info.target_device.device_type = toTRTDeviceType(device.device_type);
+  info.lower_info.target_device.gpu_id = device.gpu_id;
+  info.lower_info.target_device.dla_core = device.dla_core;
+  info.lower_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;
+
+  info.partitioning_info.target_device.device_type = toTRTDeviceType(device.device_type);
+  info.partitioning_info.target_device.gpu_id = device.gpu_id;
+  info.partitioning_info.target_device.dla_core = device.dla_core;
+  info.partitioning_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;
+
   info.partitioning_info.enabled = torch_fallback.enabled;
   info.partitioning_info.min_block_size = torch_fallback.min_block_size;
   info.partitioning_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
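Note: at::kByte now maps to TensorRT INT8, and both compile-spec translation paths (C++ API and Python bindings) enable cast_int8_inputs by default, turning it off only for genuinely quantized workflows. Condensed into one sketch, with placeholder locals rather than the exact control flow of either file:

    // Sketch only: cast Byte/Int8 inputs unless the module is actually being quantized.
    auto& precisions = info.convert_info.engine_settings.enabled_precisions;
    bool quantized = (precisions.find(nvinfer1::DataType::kINT8) != precisions.end()) || (ptq_calibrator != nullptr);
    info.partitioning_info.cast_int8_inputs = !quantized;

The device settings are also replicated into lower_info and partitioning_info so that lowering and partitioning see the same target device as conversion.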
diff --git a/tests/core/partitioning/test_type_auto_conversion.cpp b/tests/core/partitioning/test_type_auto_conversion.cpp
index 28f620b843..03c7b70e38 100644
--- a/tests/core/partitioning/test_type_auto_conversion.cpp
+++ b/tests/core/partitioning/test_type_auto_conversion.cpp
@@ -107,3 +107,63 @@ TEST(Partitioning, ImplicitAutoConversionCorrectly) {
   }
   ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
 }
+
+TEST(Partitioning, ExplicitNodeAutoInt8ConversionCorrectly) {
+  const auto graph = R"IR(
+        graph(%x.1 : Tensor,
+              %y.1 : Tensor):
+
+          %26 : int = prim::Constant[value=1]()
+          %21 : bool = prim::Constant[value=0]()
+          %60 : Device = prim::Constant[value="cuda"]()
+          %14 : NoneType = prim::Constant()
+          %3 : int = prim::Constant[value=5]()
+          %19 : int = prim::Constant[value=0]()
+          %29 : int = prim::Constant[value=2]()
+          %13 : int[] = prim::ListConstruct(%3, %3)
+          %k_.1 : Tensor = aten::ones(%13, %19, %14, %60, %14)
+          %20 : int[] = prim::ListConstruct(%19)
+          %k.1 : Tensor = aten::sum(%k_.1, %20, %21, %14)
+          %x.5 : Tensor = aten::add_(%x.1, %y.1, %26)
+          %31 : Tensor = aten::mul(%y.1, %29)
+          %x.9 : Tensor = aten::add_(%x.5, %31, %26)
+          %x.13 : Tensor = aten::add_(%x.9, %k.1, %26)
+          %x.17 : Tensor = aten::sub_(%x.13, %k.1, %26)
+          %x.21 : Tensor = aten::add_(%x.17, %k.1, %26)
+          %x.25 : Tensor = aten::sub_(%x.21, %k.1, %26)
+
+          return (%x.25))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.cast_int8_inputs = true;
+  partitioning_info.forced_fallback_operators = {"aten::ones"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+  inputs_map.insert({g->inputs()[1], {inputs[1]}});
+  input_types.insert({g->inputs()[1], {{at::kInt}}});
+
+  partitioning_info.collection_input_spec_map = inputs_map;
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  ctx.input_types_map = input_types;
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+
+  // Seeking 1 inserted aten::to converting Byte to Int (%k_.1 is a Byte Tensor)
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[0], 1));
+}
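Note on the new test: %19 holds the integer constant 0, which aten::ones receives as its dtype argument; ScalarType 0 is Byte, so %k_.1 is a Byte tensor. Because aten::ones is forced to fall back to Torch, that Byte value crosses a Torch/TensorRT boundary, and exactly one inserted aten::to cast is expected in the first segmented block. A rough eager-mode equivalent of the relevant IR lines, for illustration only:

    // %k_.1: dtype constant 0 corresponds to at::kByte (uint8).
    auto k_ = torch::ones({5, 5}, torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA));
    // %k.1: summing an integral tensor promotes the result to Long by default.
    auto k = k_.sum(0, /*keepdim=*/false);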