
fix: Properly cast intermediate Int8 tensors to TensorRT Engines in Fallback #1549


Merged
merged 2 commits on Dec 22, 2022
1 change: 1 addition & 0 deletions core/partitioning/partitioninginfo/PartitioningInfo.h
@@ -17,6 +17,7 @@ struct PartitioningInfo {
std::vector<std::string> forced_fallback_operators;
bool truncate_long_and_double;
ir::Device target_device;
bool cast_int8_inputs = false;

std::string getGPUDeviceString() const {
return "cuda:" + std::to_string(target_device.gpu_id);
56 changes: 40 additions & 16 deletions core/partitioning/shape_analysis.cpp
@@ -99,13 +99,18 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
return nullptr;
}

torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
torch::jit::Node* createCastNode(
SegmentedBlock& seg_block,
size_t index,
bool is_input,
std::string device,
bool force_create_node = false) {
auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
auto g = seg_block.g();
// if we can find an upstream aten::to node, we use its parameters for creating the new cast node
if (cast_node) {
if (cast_node && !force_create_node) {
std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
if (!is_input) {
@@ -222,29 +227,39 @@ void getSegmentsOutputByRunning(

auto target_device = partitioning_info.getGPUDeviceString();

// auto int64 <=> int32 conversion
if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
// auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
if (seg_block.target() == SegmentedBlock::kTorch) {
// First, check if there is Int64 input
for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
if (t == at::kLong) {
// we add a cast operation to cast the type to Int64
auto cast_node = createCastNode(seg_block, i, true, target_device);
seg_block.g()->prependNode(cast_node);
seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
if (partitioning_info.truncate_long_and_double) {
for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
if (t == at::kLong) {
// we add a cast operation to cast the type to Int64
auto cast_node = createCastNode(seg_block, i, true, target_device);
seg_block.g()->prependNode(cast_node);
seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
}
}
}
Collaborator
Is this just linter formatting changes?

Collaborator Author
I manually made the formatting changes to reduce redundancy of if statements, but they should be functionally equivalent to the previous version
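
For readers skimming the hunk, here is a condensed, self-contained sketch of the control-flow change the author describes. The names (BlockKind, DType, Options, needs_cast) are simplified stand-ins rather than the real Torch-TensorRT types; the snippet only illustrates how splitting the checks lets the Long-truncation and the new Byte/Int8 case share one outer branch.

// Condensed illustration of the restructured checks (hypothetical names,
// not the real Torch-TensorRT API).
#include <iostream>
#include <vector>

enum class BlockKind { kTorch, kTensorRT };
enum class DType { kLong, kByte, kFloat };

struct Options {
  bool truncate_long_and_double = true;
  bool cast_int8_inputs = true;
};

// Decide whether a cast node should be inserted for a value of type `t`
// flowing through a fallback block of kind `target`.
bool needs_cast(BlockKind target, DType t, const Options& opt) {
  if (target != BlockKind::kTorch) {
    return false; // casts are only inserted around Torch fallback blocks
  }
  // The old code folded `truncate_long_and_double` into the outer kTorch
  // check; splitting the conditions lets the Byte/Int8 case reuse the same
  // outer branch without duplicating it.
  if (t == DType::kLong && opt.truncate_long_and_double) {
    return true; // Long <-> Int truncation path
  }
  if (t == DType::kByte && opt.cast_int8_inputs) {
    return true; // Byte (Int8) -> Int path for non-quantized models
  }
  return false;
}

int main() {
  Options opt;
  std::vector<DType> sample_types = {DType::kLong, DType::kByte, DType::kFloat};
  for (DType t : sample_types) {
    std::cout << needs_cast(BlockKind::kTorch, t, opt) << " "; // prints: 1 1 0
  }
  std::cout << std::endl;
  return 0;
}
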

}

for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
at::ScalarType t = cur_ivalue.toTensor().scalar_type();
if (t == at::kLong) {

// If the output has type Long and truncation was requested, insert a truncation cast
if (t == at::kLong && partitioning_info.truncate_long_and_double) {
auto cast_node = createCastNode(seg_block, i, false, target_device);
seg_block.g()->appendNode(cast_node);
seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
} else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
// If the output has type Byte and Int8 casting was requested, insert an Integer cast
auto cast_node = createCastNode(seg_block, i, false, target_device, /*force_create_node=*/true);
seg_block.g()->appendNode(cast_node);
seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
}
}
}
@@ -254,11 +269,13 @@ void getSegmentsOutputByRunning(
std::vector<std::vector<int64_t>> input_shapes;
std::vector<at::ScalarType> input_types;
for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
auto current_input = seg_block.raw_inputs()[i];

if (ivalues_maps[current_input].isTensor()) {
// set the input_shape and data_type
// we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
// shape inference
auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
auto cur_ivalue = ivalues_maps[current_input];
at::ScalarType t = cur_ivalue.toTensor().scalar_type();

if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
@@ -271,10 +288,16 @@
cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
}

c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
if (dtype == c10::nullopt) {
TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
} else if (dtype && dtype.value() == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs) {
// Special case to ensure input IValues to TensorRT engine are not Int8 type if the
// model itself is not quantized
cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
}

if (cur_ivalue.toTensor().sizes().size() == 0) {
// handle Scalar types, which have sizes of []
input_shapes.push_back(util::toVec(util::toDims(c10::List<int64_t>({1}))));
@@ -297,6 +320,7 @@ void runShapeAnalysis(
const ir::ShapeMode& shape_mode) {
// register every segment's input shape and its running output IValues
for (auto& seg_block : ctx->partitioned_blocks[block]) {
LOG_GRAPH("Running shape analysis on block " << seg_block);
torch::jit::ConstantPooling(seg_block.g());
getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
}
1 change: 1 addition & 0 deletions core/util/trt_util.cpp
@@ -252,6 +252,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_map(
{at::kHalf, nvinfer1::DataType::kHALF},
{at::kInt, nvinfer1::DataType::kINT32},
{at::kChar, nvinfer1::DataType::kINT8},
{at::kByte, nvinfer1::DataType::kINT8},
{at::kBool, nvinfer1::DataType::kBOOL}};
return at_trt_type_map;
}
3 changes: 3 additions & 0 deletions cpp/src/compile_spec.cpp
@@ -167,8 +167,11 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
internal.convert_info.engine_settings.dla_local_dram_size = external.dla_local_dram_size;
internal.convert_info.engine_settings.dla_global_dram_size = external.dla_global_dram_size;

internal.partitioning_info.cast_int8_inputs = true;

if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
internal.convert_info.engine_settings.enabled_precisions.end()) {
internal.partitioning_info.cast_int8_inputs = false;
if (external.ptq_calibrator) {
internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
} else {
17 changes: 17 additions & 0 deletions py/torch_tensorrt/csrc/tensorrt_classes.cpp
@@ -300,11 +300,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
}

info.partitioning_info.cast_int8_inputs = true;

if (ptq_calibrator) {
info.convert_info.engine_settings.calibrator = ptq_calibrator;
info.partitioning_info.cast_int8_inputs = false;
} else {
if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
info.convert_info.engine_settings.enabled_precisions.end()) {
info.partitioning_info.cast_int8_inputs = false;
info.lower_info.unfreeze_module = true;
info.lower_info.disable_cse = true;
}
@@ -313,10 +317,23 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
info.convert_info.engine_settings.refit = refit;
info.convert_info.engine_settings.debug = debug;

// Specify + replicate device settings for phases requiring it
info.convert_info.engine_settings.device.device_type = toTRTDeviceType(device.device_type);
info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
info.convert_info.engine_settings.device.dla_core = device.dla_core;
info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;

info.lower_info.target_device.device_type = toTRTDeviceType(device.device_type);
info.lower_info.target_device.gpu_id = device.gpu_id;
info.lower_info.target_device.dla_core = device.dla_core;
info.lower_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;

info.partitioning_info.target_device.device_type = toTRTDeviceType(device.device_type);
info.partitioning_info.target_device.gpu_id = device.gpu_id;
info.partitioning_info.target_device.dla_core = device.dla_core;
info.partitioning_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;

info.partitioning_info.enabled = torch_fallback.enabled;
info.partitioning_info.min_block_size = torch_fallback.min_block_size;
info.partitioning_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
60 changes: 60 additions & 0 deletions tests/core/partitioning/test_type_auto_conversion.cpp
@@ -107,3 +107,63 @@ TEST(Partitioning, ImplicitAutoConversionCorrectly) {
}
ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
}

TEST(Partitioning, ExplicitNodeAutoInt8ConversionCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor,
%y.1 : Tensor):

%26 : int = prim::Constant[value=1]()
%21 : bool = prim::Constant[value=0]()
%60 : Device = prim::Constant[value="cuda"]()
%14 : NoneType = prim::Constant()
%3 : int = prim::Constant[value=5]()
%19 : int = prim::Constant[value=0]()
%29 : int = prim::Constant[value=2]()
%13 : int[] = prim::ListConstruct(%3, %3)
%k_.1 : Tensor = aten::ones(%13, %19, %14, %60, %14)
%20 : int[] = prim::ListConstruct(%19)
%k.1 : Tensor = aten::sum(%k_.1, %20, %21, %14)
%x.5 : Tensor = aten::add_(%x.1, %y.1, %26)
%31 : Tensor = aten::mul(%y.1, %29)
%x.9 : Tensor = aten::add_(%x.5, %31, %26)
%x.13 : Tensor = aten::add_(%x.9, %k.1, %26)
%x.17 : Tensor = aten::sub_(%x.13, %k.1, %26)
%x.21 : Tensor = aten::add_(%x.17, %k.1, %26)
%x.25 : Tensor = aten::sub_(%x.21, %k.1, %26)

return (%x.25))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get(), true);

torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
partitioning_info.enabled = true;
partitioning_info.cast_int8_inputs = true;
partitioning_info.forced_fallback_operators = {"aten::ones"};
partitioning_info.truncate_long_and_double = true;
std::vector<torch_tensorrt::core::ir::Input> inputs;
inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));

std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
inputs_map.insert({g->inputs()[0], {inputs[0]}});
input_types.insert({g->inputs()[0], {{at::kFloat}}});
inputs_map.insert({g->inputs()[1], {inputs[1]}});
input_types.insert({g->inputs()[1], {{at::kInt}}});

partitioning_info.collection_input_spec_map = inputs_map;
torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
ctx.input_types_map = input_types;
torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
torch_tensorrt::core::partitioning::partition(&ctx);
auto segmented_blocks = ctx.partitioned_blocks.begin()->second;

for (auto& seg_block : segmented_blocks) {
LOG_DEBUG(seg_block << " cur seg block");
}

// Seeking 1 inserted aten::to converting Byte to Int (%k_.1 is a Byte Tensor, since aten::ones is called with dtype constant 0, i.e. at::kByte)
ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[0], 1));
}