feat: Add support for providing input datatypes in TRTorch #510

Merged 17 commits on Jul 22, 2021

Changes from 9 commits
9 changes: 7 additions & 2 deletions README.md
@@ -20,7 +20,11 @@ More Information / System Architecture:
...
auto compile_settings = trtorch::CompileSpec(dims);
// FP16 execution
-compile_settings.op_precision = torch::kFloat;
+compile_settings.op_precision = torch::kHalf;
+// Set input datatypes. Allowed options: torch::{kFloat, kHalf, kChar, kInt32, kBool}
+// Size of input_dtypes should match the number of inputs to the network.
+// If input_dtypes is not set, input tensors default to float32
+compile_settings.input_dtypes = {torch::kHalf};
// Compile module
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
// Run like normal
@@ -43,7 +47,8 @@ compile_settings = {
"max": [1, 3, 1024, 1024]
}, # For static size [1, 3, 224, 224]
],
"op_precision": torch.half # Run with FP16
"op_precision": torch.half, # Run with FP16
"input_dtypes": [torch.half] # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
}

trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
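As a worked example of the C++ path above, a minimal end-to-end sketch; the model path, input size, and CUDA placement are illustrative assumptions, not part of this PR:

```cpp
#include "torch/script.h"
#include "trtorch/trtorch.h"

int main() {
  // Hypothetical TorchScript module; path is a placeholder
  auto ts_mod = torch::jit::load("model.ts");
  auto compile_settings = trtorch::CompileSpec({{1, 3, 224, 224}});
  compile_settings.op_precision = torch::kHalf;
  // One dtype per network input; defaults to float32 when unset
  compile_settings.input_dtypes = {torch::kHalf};
  auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
  // Runtime inputs should match the declared dtypes (kHalf here)
  auto in = torch::randn({1, 3, 224, 224}, torch::kCUDA).to(torch::kHalf);
  auto out = trt_mod.forward({in});
  return 0;
}
```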
10 changes: 5 additions & 5 deletions core/compiler.cpp
@@ -196,7 +196,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
LOG_INFO(*g << "(LoweringGraph)\n");

// segment the graph and convert segmented TensorRT block
-auto segmented_blocks = partitioning::Partition(g, convert_cfg.input_ranges, cfg.partition_info);
+auto segmented_blocks = partitioning::Partition(g, convert_cfg.inputs, cfg.partition_info);
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
LOG_WARNING("Didn't generate any TensorRT engines, the compiler did nothing\n");
return mod;
@@ -210,16 +210,16 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
for (auto& seg_block : segmented_blocks) {
std::string cur_block_target =
seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
-LOG_INFO(*seg_block.g() << "(MiniGraphIn" << cur_block_target << "Block)\n");
+LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
std::ostringstream trt_engine_id;
trt_engine_id << reinterpret_cast<const int*>(&seg_block);
if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
std::vector<ir::InputRange> input_ranges;
std::vector<ir::Input> inputs;
for (auto& shape : seg_block.in_shape()) {
-input_ranges.push_back(ir::InputRange(shape));
+inputs.push_back(ir::Input(shape));
}
// update the input ranges for each segment
-convert_cfg.input_ranges = input_ranges;
+convert_cfg.inputs = inputs;
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_cfg.engine_settings.device;
2 changes: 1 addition & 1 deletion core/compiler.h
@@ -12,7 +12,7 @@ namespace trtorch {
namespace core {

struct CompileSpec {
CompileSpec(std::vector<ir::InputRange> input_ranges) : convert_info(std::move(input_ranges)) {}
CompileSpec(std::vector<ir::Input> inputs) : convert_info(std::move(inputs)) {}
conversion::ConversionInfo convert_info;
partitioning::PartitionInfo partition_info;
};
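For illustration, constructing the core spec from the new type might look like the sketch below; the shape-only ir::Input constructor is inferred from the partitioning code above, so treat the signature as an assumption:

```cpp
// Hypothetical usage sketch, inside the trtorch::core namespace
std::vector<ir::Input> inputs;
inputs.push_back(ir::Input({1, 3, 224, 224})); // shape-only; dtype defaults apply
CompileSpec cfg(std::move(inputs));
```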
36 changes: 23 additions & 13 deletions core/conversion/conversion.cpp
@@ -128,7 +128,7 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
void AddInputs(
ConversionCtx* ctx,
at::ArrayRef<const torch::jit::Value*> inputs,
std::vector<ir::InputRange>& input_dims) {
std::vector<ir::Input>& input_specs) {
std::vector<const torch::jit::Value*> input_tensors;
for (auto in : inputs) {
// Disregarding inputs that are not tensors
@@ -142,29 +142,39 @@ void AddInputs(
}
}

+std::stringstream ss;
+ss << "Input Dimension Specs: [\n";
+for (auto i : input_specs) {
+  ss << "    " << i << ",";
+}
+ss << ']';
+LOG_DEBUG(ss.str());

TRTORCH_CHECK(
-    input_tensors.size() == input_dims.size(),
+    input_tensors.size() == input_specs.size(),
    "Expected dimension specifications for all input tensors"
-        << ", but found " << input_tensors.size() << " input tensors and " << input_dims.size()
+        << ", but found " << input_tensors.size() << " input tensors and " << input_specs.size()
<< " dimension specs (conversion.AddInputs)");

auto profile = ctx->builder->createOptimizationProfile();

for (size_t i = 0; i < input_tensors.size(); i++) {
auto in = input_tensors[i];
-auto dims = input_dims[i];
+auto spec = input_specs[i];
std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
LOG_INFO(
ctx->logger, "Adding Input " << in->debugName() << " named " << name << " in engine (conversion.AddInputs)");
LOG_DEBUG(ctx->logger, "Input shape set to " << dims.input_shape);
auto trt_in = ctx->net->addInput(name.c_str(), ctx->input_type, dims.input_shape);
ctx->logger,
"Adding Input " << in->debugName() << " (named: " << name << "): " << spec << " in engine (conversion.AddInputs)");

auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));

profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, dims.min);
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, dims.opt);
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, dims.max);
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, spec.min);
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, spec.opt);
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, spec.max);

-if (dims.input_is_dynamic) {
+if (spec.input_is_dynamic) {
ctx->input_is_dynamic = true;
}

@@ -178,7 +188,7 @@ void AddInputs(

ctx->cfg->addOptimizationProfile(profile);
#if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR >= 1)
if (ctx->op_precision == nvinfer1::DataType::kINT8) {
if (ctx->enabled_precisions.find(nvinfer1::DataType::kINT8) != ctx->enabled_precisions.end()) {
ctx->cfg->setCalibrationProfile(profile);
}
#endif
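The AddInputs rework follows the standard TensorRT pattern of declaring each network input with its own datatype and registering min/opt/max shapes on an optimization profile. A free-standing sketch of that pattern (the helper name and parameters are hypothetical, not the TRTorch API):

```cpp
#include "NvInfer.h"

nvinfer1::ITensor* AddTypedInput(
    nvinfer1::INetworkDefinition* net,
    nvinfer1::IOptimizationProfile* profile,
    const char* name,
    nvinfer1::DataType dtype, // taken from the user-provided input_dtypes
    nvinfer1::Dims min, nvinfer1::Dims opt, nvinfer1::Dims max) {
  // dtype is now per-input rather than a single global input_type
  auto* in = net->addInput(name, dtype, opt);
  profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, min);
  profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, opt);
  profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, max);
  return in;
}
```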
@@ -350,7 +360,7 @@ void ConvertBlockToNetDef(

auto inputs = b->inputs();
AddParamsToCtxValueMap(ctx, static_params);
-AddInputs(ctx, inputs, build_info.input_ranges);
+AddInputs(ctx, inputs, build_info.inputs);

auto nodes = b->nodes();

6 changes: 3 additions & 3 deletions core/conversion/conversion.h
@@ -12,10 +12,10 @@ namespace core {
namespace conversion {

struct ConversionInfo {
std::vector<ir::InputRange> input_ranges;
std::vector<ir::Input> inputs;
BuilderSettings engine_settings;
ConversionInfo(std::vector<ir::InputRange> input_ranges)
: input_ranges(std::move(input_ranges)), engine_settings(BuilderSettings()) {}
ConversionInfo(std::vector<ir::Input> inputs)
: inputs(std::move(inputs)), engine_settings(BuilderSettings()) {}
};

// TODO: REMOVE GRAPH AND PARAMS AND MOVE FULLY TO INLINED CONSTANTS
53 changes: 28 additions & 25 deletions core/conversion/conversionctx/ConversionCtx.cpp
@@ -10,8 +10,11 @@ namespace conversion {
// clang-format off
std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
os << "Settings requested for TensorRT engine:" \
<< "\n Operating Precision: " << s.op_precision \
<< "\n TF32 Floating Point Computation Enabled: " << !s.disable_tf32 \
<< "\n Enabled Precisions: ";
for (auto p = s.enabled_precisions.begin(); p != s.enabled_precisions.end(); ++p) {
os << *p << ' ';
}
os << "\n TF32 Floating Point Computation Enabled: " << !s.disable_tf32 \
<< "\n Truncate Long and Double: " << s.truncate_long_and_double \
<< "\n Make Refittable Engine: " << s.refit \
<< "\n Debuggable Engine: " << s.debug \
@@ -57,30 +60,30 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
LOG_DEBUG(build_settings);
cfg = builder->createBuilderConfig();

-switch (settings.op_precision) {
-  case nvinfer1::DataType::kHALF:
-    TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
-    cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
-    input_type = nvinfer1::DataType::kHALF;
-    break;
-  case nvinfer1::DataType::kINT8:
-    TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
-    cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
-    if (!settings.strict_types) {
-      cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
-    }
-    input_type = nvinfer1::DataType::kFLOAT;
-    TRTORCH_CHECK(
-        settings.calibrator != nullptr,
-        "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
-    cfg->setInt8Calibrator(settings.calibrator);
-    break;
-  case nvinfer1::DataType::kFLOAT:
-  default:
-    input_type = nvinfer1::DataType::kFLOAT;
-    break;
-}
-op_precision = settings.op_precision;
+for (auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
+  switch (*p) {
+    case nvinfer1::DataType::kHALF:
+      TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
+      cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
+      break;
+    case nvinfer1::DataType::kINT8:
+      TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
+      cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
+      TRTORCH_CHECK(
+          settings.calibrator != nullptr,
+          "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
+      cfg->setInt8Calibrator(settings.calibrator);
+      break;
+    case nvinfer1::DataType::kFLOAT:
+    case nvinfer1::DataType::kINT32:
+    case nvinfer1::DataType::kBOOL:
+    default:
+      break;
+  }
+}
+
+enabled_precisions = settings.enabled_precisions;
+input_dtypes = settings.input_dtypes;

if (settings.disable_tf32) {
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
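Since several precisions can now be enabled at once, the single-enum comparisons elsewhere in the codebase become set-membership tests. An illustrative helper (not part of the diff):

```cpp
#include <set>
#include "NvInfer.h"

// Several precisions may be enabled simultaneously, so each one is tested
// by set membership rather than a single `op_precision == ...` comparison.
void ApplyPrecisionFlags(nvinfer1::IBuilderConfig* cfg,
                         const std::set<nvinfer1::DataType>& enabled) {
  if (enabled.count(nvinfer1::DataType::kHALF)) {
    cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
  }
  if (enabled.count(nvinfer1::DataType::kINT8)) {
    cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
  }
}
```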
@@ -118,7 +121,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
static_cast<int>(settings.device.dla_core) < nbDLACores,
"Configured DLA Core ID: " << settings.device.dla_core
<< " not available. Total number of available DLA Cores: " << nbDLACores);
-TRTORCH_CHECK(settings.op_precision != nvinfer1::DataType::kFLOAT, "DLA supports only fp16 or int8 precision");
+TRTORCH_CHECK(settings.enabled_precisions.find(nvinfer1::DataType::kFLOAT) == settings.enabled_precisions.end(), "DLA supports only fp16 or int8 precision");
cfg->setDLACore(settings.device.dla_core);
}
}
8 changes: 5 additions & 3 deletions core/conversion/conversionctx/ConversionCtx.h
@@ -3,6 +3,7 @@
#include <map>
#include <memory>
#include <unordered_map>
#include <set>

#include "NvInfer.h"
#include "torch/csrc/jit/ir/ir.h"
@@ -23,7 +24,8 @@ struct Device {
};

struct BuilderSettings {
-  nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
+  std::set<nvinfer1::DataType> enabled_precisions = {nvinfer1::DataType::kFLOAT};
+  std::vector<nvinfer1::DataType> input_dtypes;
bool disable_tf32 = false;
bool refit = false;
bool debug = false;
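A sketch of populating the new BuilderSettings fields (values illustrative; assumes code living in the same trtorch::core::conversion namespace as the struct):

```cpp
BuilderSettings settings;
settings.enabled_precisions = {nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF};
settings.input_dtypes = {nvinfer1::DataType::kHALF}; // one entry per network input
```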
@@ -57,8 +59,8 @@ struct ConversionCtx {
nvinfer1::IBuilder* builder;
nvinfer1::INetworkDefinition* net;
nvinfer1::IBuilderConfig* cfg;
-  nvinfer1::DataType input_type;
-  nvinfer1::DataType op_precision;
+  std::vector<nvinfer1::DataType> input_dtypes;
+  std::set<nvinfer1::DataType> enabled_precisions;
BuilderSettings settings;
util::logging::TRTorchLogger logger;
// Pointers to data that needs to remain alive until conversion is done
3 changes: 2 additions & 1 deletion core/conversion/converters/impl/activation.cpp
@@ -177,7 +177,8 @@ auto acthardtanh TRTORCH_UNUSED =
std::string pluginName = "CustomGeluPluginDynamic";
nvinfer1::PluginFieldCollection fc;
std::vector<nvinfer1::PluginField> f;
-int type_id = ctx->settings.op_precision == nvinfer1::DataType::kFLOAT
+// REVIEW: is this right?
+int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) == ctx->settings.enabled_precisions.end()
? 0
: 1; // Integer encoding the DataType (0: FP32, 1: FP16)
f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
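Regarding the REVIEW note: the expression selects the plugin's FP16 kernel exactly when kHALF is among the enabled precisions. Spelled out as an equivalent standalone helper (for clarity only, not part of the diff):

```cpp
#include <set>
#include "NvInfer.h"

// 0 selects the GELU plugin's FP32 kernel, 1 its FP16 kernel
int GeluTypeId(const std::set<nvinfer1::DataType>& enabled_precisions) {
  bool fp16_enabled =
      enabled_precisions.find(nvinfer1::DataType::kHALF) != enabled_precisions.end();
  return fp16_enabled ? 1 : 0;
}
```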
2 changes: 1 addition & 1 deletion core/ir/BUILD
@@ -13,7 +13,7 @@ cc_library(
"ir.h"
],
srcs = [
"InputRange.cpp",
"Input.cpp"
],
deps = [
"@tensorrt//:nvinfer",