Skip to content

Implementation of the PyTorch Backend API #194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ build --cxxopt='-std=c++14'
build:python --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
build:python --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0"
build:python --define=abi=pre_cxx11_abi
build:python --define=target_lang=python

build:pre_cxx11_abi --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
build:pre_cxx11_abi --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0"
Expand Down
6 changes: 3 additions & 3 deletions .github/pr-labels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
"component: evaluators":
- core/conversion/evaluators/**/*

"component: execution":
- core/execution/**/*
"component: runtime":
- core/runtime/**/*

"component: lowering":
- core/lowering/**/*
Expand All @@ -32,4 +32,4 @@
"documentation":
- docs/**/*
- docsrc/**/*

2 changes: 1 addition & 1 deletion BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pkg_tar(
"//core/conversion/tensorcontainer:include",
"//core/conversion/evaluators:include",
"//core/conversion/converters/impl/plugins:include",
"//core/execution:include",
"//core/runtime:include",
"//core/lowering:include",
"//core/lowering/passes:include",
"//core/util:include",
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ More Information / System Architecture:
#include "trtorch/trtorch.h"

...
auto compile_settings = trtorch::ExtraInfo(dims);
auto compile_settings = trtorch::CompileSpec(dims);
// FP16 execution
compile_settings.op_precision = torch::kFloat;
// Compile module
Expand Down Expand Up @@ -54,7 +54,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
```

> Notes on running in lower precisions:
> - Set precision with extra_info.op_precision
> - Set precision with compile_spec.op_precision
> - The module should be left in FP32 before compilation (FP16 can support half tensor models)
> - In FP16 only input tensors should be converted to FP16, other precisions use FP32

Expand Down
15 changes: 12 additions & 3 deletions core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ config_setting(
}
)

config_setting(
name = "python_core",
values = {
"define": "target_lang=python"
}
)

cc_library(
name = "core",
hdrs = [
Expand All @@ -17,7 +24,7 @@ cc_library(
],
deps = [
"//core/conversion",
"//core/execution",
"//core/runtime",
"//core/lowering",
"//core/util/logging",
"@tensorrt//:nvinfer"
Expand All @@ -28,11 +35,13 @@ cc_library(
alwayslink=True,
)


load("@rules_pkg//:pkg.bzl", "pkg_tar")

pkg_tar(
name = "include",
package_dir = "core/",
srcs = ["compiler.h"],
srcs = [
"backend.h",
"compiler.h",
],
)
12 changes: 6 additions & 6 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include "core/lowering/lowering.h"
#include "core/conversion/conversion.h"
#include "core/execution/execution.h"
#include "core/runtime/runtime.h"

namespace trtorch {
namespace core {
Expand All @@ -42,15 +42,15 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str


void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(mod._ivalue()->name(), serialized_engine);
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
// Get required metadata about the engine out
auto num_io = engine_ptr->num_io;
auto name = engine_ptr->name;

// Add the engine as an attribute of the module, this will let the engine be serialized and deserialized
mod.register_attribute(
name,
c10::getCustomClassType<c10::intrusive_ptr<execution::TRTEngine>>(),
c10::getCustomClassType<c10::intrusive_ptr<runtime::TRTEngine>>(),
c10::IValue(std::move(engine_ptr)),
false
);
Expand Down Expand Up @@ -125,7 +125,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name,
ExtraInfo cfg) {
CompileSpec cfg) {

// Go through Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = lowering::Lower(mod, method_name);
Expand All @@ -137,12 +137,12 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,

LOG_INFO(*g << "(CompileGraph)\n");

auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params);
auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
return std::move(engine);
}

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
ExtraInfo cfg) {
CompileSpec cfg) {
// TODO: Should be doing a functional transform but need PR #31978
// [jit] More robust mangling
//torch::jit::script::Module new_mod = mod.clone();
Expand Down
8 changes: 4 additions & 4 deletions core/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
namespace trtorch {
namespace core {

struct ExtraInfo {
ExtraInfo(std::vector<conversion::InputRange> input_ranges)
struct CompileSpec {
CompileSpec(std::vector<conversion::InputRange> input_ranges)
: convert_info(std::move(input_ranges)) {}
conversion::ConversionInfo convert_info;
};

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name, ExtraInfo cfg);
std::string method_name, CompileSpec cfg);

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg);
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);

} // namespace core
} // namespace trtorch
2 changes: 1 addition & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
}
input_type = nvinfer1::DataType::kFLOAT;
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
cfg->setInt8Calibrator(settings.calibrator);
break;
case nvinfer1::DataType::kFLOAT:
Expand Down
8 changes: 4 additions & 4 deletions core/execution/BUILD → core/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ config_setting(
)

cc_library(
name = "execution",
name = "runtime",
hdrs = [
"execution.h",
"runtime.h",
],
srcs = [
"TRTEngine.cpp",
Expand All @@ -30,6 +30,6 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")

pkg_tar(
name = "include",
package_dir = "core/execution/",
srcs = ["execution.h"],
package_dir = "core/runtime/",
srcs = ["runtime.h"],
)
9 changes: 5 additions & 4 deletions core/execution/TRTEngine.cpp → core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
#include "torch/csrc/jit/frontend/function_schema_parser.h"

#include "core/util/prelude.h"
#include "core/execution/execution.h"
#include "core/runtime/runtime.h"

namespace trtorch {
namespace core {
namespace execution {
namespace runtime {

std::string slugify(std::string s) {
std::replace(s.begin(), s.end(), '.', '_');
Expand Down Expand Up @@ -81,6 +81,7 @@ TRTEngine::~TRTEngine() {
// return c10::List<at::Tensor>(output_vec);
// }

namespace {
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion = torch::class_<TRTEngine>("tensorrt", "Engine")
.def(torch::init<std::string>())
// TODO: .def("__call__", &TRTEngine::Run)
Expand All @@ -94,7 +95,7 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion = torch::class_<TRTEngine>("te
return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
}
);

} // namespace execution
} // namespace
} // namespace runtime
} // namespace core
} // namespace trtorch
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
#include "torch/csrc/jit/runtime/custom_operator.h"

#include "core/util/prelude.h"
#include "core/execution/execution.h"
#include "core/runtime/runtime.h"

namespace trtorch {
namespace core {
namespace execution {
namespace runtime {

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
LOG_DEBUG("Attempting to run engine (ID: " << compiled_engine->name << ")");
Expand All @@ -30,7 +30,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
gpu_handles.push_back(contig_inputs.back().data_ptr());
}

TRTORCH_CHECK(compiled_engine->exec_ctx->allInputDimensionsSpecified(), "Not enough inputs provided (execution.RunCudaEngine)");
TRTORCH_CHECK(compiled_engine->exec_ctx->allInputDimensionsSpecified(), "Not enough inputs provided (runtime.RunCudaEngine)");

std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
for (size_t o = inputs.size(); o < (compiled_engine->num_io.first + compiled_engine->num_io.second); o++) {
Expand All @@ -53,6 +53,6 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("execute_engine", execute_engine);
}

} // namespace execution
} // namespace runtime
} // namespace core
} // namespace trtorch
4 changes: 2 additions & 2 deletions core/execution/execution.h → core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace trtorch {
namespace core {
namespace execution {
namespace runtime {

using EngineID = int64_t;

Expand All @@ -35,6 +35,6 @@ struct TRTEngine : torch::CustomClassHolder {

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);

} // namespace execution
} // namespace runtime
} // namespace core
} // namespace trtorch
2 changes: 1 addition & 1 deletion cpp/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cc_library(
"include/trtorch/ptq.h"
],
srcs = [
"src/extra_info.cpp",
"src/compile_spec.cpp",
"src/logging.cpp",
"src/trtorch.cpp",
"src/ptq.cpp"
Expand Down
16 changes: 8 additions & 8 deletions cpp/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace trtorch {
* Settings data structure for TRTorch compilation
*
*/
struct TRTORCH_API ExtraInfo {
struct TRTORCH_API CompileSpec {
/**
* @brief A struct to hold an input range (used by TensorRT Optimization profile)
*
Expand Down Expand Up @@ -132,10 +132,10 @@ struct TRTORCH_API ExtraInfo {
kSAFE_DLA,
};

ExtraInfo(std::vector<InputRange> input_ranges)
CompileSpec(std::vector<InputRange> input_ranges)
: input_ranges(std::move(input_ranges)) {}
ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);

// Defaults should reflect TensorRT defaults for BuilderConfig

Expand Down Expand Up @@ -236,27 +236,27 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& mo
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
*
* @param module: torch::jit::script::Module - Existing TorchScript module
* @param info: trtorch::ExtraInfo - Compilation settings
* @param info: trtorch::CompileSpec - Compilation settings
*
 * Takes an existing TorchScript module and a set of settings to configure the compiler
* and will convert methods to JIT Graphs which call equivalent TensorRT engines
*
* Converts specifically the forward method of a TorchScript Module
*/
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info);
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec info);

/**
* @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
*
* @param module: torch::jit::script::Module - Existing TorchScript module
* @param method_name: std::string - Name of method to compile
* @param info: trtorch::ExtraInfo - Compilation settings
* @param info: trtorch::CompileSpec - Compilation settings
*
 * Takes an existing TorchScript module and a set of settings to configure the compiler
* and will convert selected method to a serialized TensorRT engine which can be run with
* TensorRT
*/
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info);
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, CompileSpec info);

namespace ptq {
/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/api/include/trtorch/ptq.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class Int8Calibrator : Algorithm {
/**
* @brief operator to cast to nvinfer1::IInt8Calibrator*
*
 * Convenience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec
*
* @return nvinfer1::IInt8Calibrator*
*/
Expand Down Expand Up @@ -259,7 +259,7 @@ class Int8CacheCalibrator : Algorithm {
/**
* @brief operator to cast to nvinfer1::IInt8Calibrator*
*
 * Convenience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
* Convience function to convert to a IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec
*
* @return nvinfer1::IInt8Calibrator*
*/
Expand Down
Loading