Skip to content

Commit b8fa228

Browse files
committed
refactor!: Renaming extra info to compile spec to be more consistent
with other backends and between APIs in TRTorch BREAKING CHANGE: This changes the top level api for setting the specification for compilation, a simple find and replace should allow users to port forward Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent e4a4574 commit b8fa228

27 files changed

+194
-187
lines changed

Diff for: .bazelrc

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ build --cxxopt='-std=c++14'
2929
build:python --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
3030
build:python --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0"
3131
build:python --define=abi=pre_cxx11_abi
32+
build:python --define=target_lang=python
3233

3334
build:pre_cxx11_abi --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
3435
build:pre_cxx11_abi --linkopt="-D_GLIBCXX_USE_CXX11_ABI=0"

Diff for: README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ More Information / System Architecture:
1818
#include "trtorch/trtorch.h"
1919

2020
...
21-
auto compile_settings = trtorch::ExtraInfo(dims);
21+
auto compile_settings = trtorch::CompileSpec(dims);
2222
// FP16 execution
2323
compile_settings.op_precision = torch::kFloat;
2424
// Compile module
@@ -54,7 +54,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
5454
```
5555

5656
> Notes on running in lower precisions:
57-
> - Set precision with extra_info.op_precision
57+
> - Set precision with compile_spec.op_precision
5858
> - The module should be left in FP32 before compilation (FP16 can support half tensor models)
5959
> - In FP16 only input tensors should be converted to FP16, other precisions use FP32
6060

Diff for: core/compiler.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
#include "core/lowering/lowering.h"
2222
#include "core/conversion/conversion.h"
23-
#include "core/execution/execution.h"
23+
#include "core/runtime/runtime.h"
2424

2525
namespace trtorch {
2626
namespace core {
@@ -42,15 +42,15 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str
4242

4343

4444
void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
45-
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(mod._ivalue()->name(), serialized_engine);
45+
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
4646
// Get required metadata about the engine out
4747
auto num_io = engine_ptr->num_io;
4848
auto name = engine_ptr->name;
4949

5050
// Add the engine as an attribute of the module, this will let the engine be serialized and deserialized
5151
mod.register_attribute(
5252
name,
53-
c10::getCustomClassType<c10::intrusive_ptr<execution::TRTEngine>>(),
53+
c10::getCustomClassType<c10::intrusive_ptr<runtime::TRTEngine>>(),
5454
c10::IValue(std::move(engine_ptr)),
5555
false
5656
);
@@ -125,7 +125,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
125125

126126
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
127127
std::string method_name,
128-
ExtraInfo cfg) {
128+
CompileSpec cfg) {
129129

130130
// Go through Lowering to simplify graph and extract weight parameters
131131
auto graph_and_parameters = lowering::Lower(mod, method_name);
@@ -137,12 +137,12 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
137137

138138
LOG_INFO(*g << "(CompileGraph)\n");
139139

140-
auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params);
140+
auto engine = conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
141141
return std::move(engine);
142142
}
143143

144144
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
145-
ExtraInfo cfg) {
145+
CompileSpec cfg) {
146146
// TODO: Should be doing a functional transform but need PR #31978
147147
// [jit] More robust mangling
148148
//torch::jit::script::Module new_mod = mod.clone();

Diff for: core/compiler.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77
namespace trtorch {
88
namespace core {
99

10-
struct ExtraInfo {
11-
ExtraInfo(std::vector<conversion::InputRange> input_ranges)
10+
struct CompileSpec {
11+
CompileSpec(std::vector<conversion::InputRange> input_ranges)
1212
: convert_info(std::move(input_ranges)) {}
1313
conversion::ConversionInfo convert_info;
1414
};
1515

1616
bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);
1717

1818
std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
19-
std::string method_name, ExtraInfo cfg);
19+
std::string method_name, CompileSpec cfg);
2020

21-
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg);
21+
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
2222

2323
} // namespace core
2424
} // namespace trtorch

Diff for: core/conversion/conversionctx/ConversionCtx.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
5555
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
5656
}
5757
input_type = nvinfer1::DataType::kFLOAT;
58-
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
58+
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
5959
cfg->setInt8Calibrator(settings.calibrator);
6060
break;
6161
case nvinfer1::DataType::kFLOAT:

Diff for: cpp/api/BUILD

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ cc_library(
99
"include/trtorch/ptq.h"
1010
],
1111
srcs = [
12-
"src/extra_info.cpp",
12+
"src/compile_spec.cpp",
1313
"src/logging.cpp",
1414
"src/trtorch.cpp",
1515
"src/ptq.cpp"

Diff for: cpp/api/README.md

+8-8
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace trtorch {
3131
* Settings data structure for TRTorch compilation
3232
*
3333
*/
34-
struct TRTORCH_API ExtraInfo {
34+
struct TRTORCH_API CompileSpec {
3535
/**
3636
* @brief A struct to hold an input range (used by TensorRT Optimization profile)
3737
*
@@ -132,10 +132,10 @@ struct TRTORCH_API ExtraInfo {
132132
kSAFE_DLA,
133133
};
134134

135-
ExtraInfo(std::vector<InputRange> input_ranges)
135+
CompileSpec(std::vector<InputRange> input_ranges)
136136
: input_ranges(std::move(input_ranges)) {}
137-
ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
138-
ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
137+
CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
138+
CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
139139

140140
// Defaults should reflect TensorRT defaults for BuilderConfig
141141

@@ -236,27 +236,27 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& mo
236236
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
237237
*
238238
* @param module: torch::jit::script::Module - Existing TorchScript module
239-
* @param info: trtorch::ExtraInfo - Compilation settings
239+
* @param info: trtorch::CompileSpec - Compilation settings
240240
*
241241
* Takes an existing TorchScript module and a set of settings to configure the compiler
242242
* and will convert methods to JIT Graphs which call equivalent TensorRT engines
243243
*
244244
* Converts specifically the forward method of a TorchScript Module
245245
*/
246-
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info);
246+
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec info);
247247

248248
/**
249249
* @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
250250
*
251251
* @param module: torch::jit::script::Module - Existing TorchScript module
252252
* @param method_name: std::string - Name of method to compile
253-
* @param info: trtorch::ExtraInfo - Compilation settings
253+
* @param info: trtorch::CompileSpec - Compilation settings
254254
*
255255
* Takes an existing TorchScript module and a set of settings to configure the compiler
256256
* and will convert selected method to a serialized TensorRT engine which can be run with
257257
* TensorRT
258258
*/
259-
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info);
259+
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, CompileSpec info);
260260

261261
namespace ptq {
262262
/**

Diff for: cpp/api/include/trtorch/ptq.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class Int8Calibrator : Algorithm {
145145
/**
146146
* @brief operator to cast to nvinfer1::IInt8Calibrator*
147147
*
148-
* Convenience function to convert to an IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
148+
* Convenience function to convert to an IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec
149149
*
150150
* @return nvinfer1::IInt8Calibrator*
151151
*/
@@ -259,7 +259,7 @@ class Int8CacheCalibrator : Algorithm {
259259
/**
260260
* @brief operator to cast to nvinfer1::IInt8Calibrator*
261261
*
262-
* Convenience function to convert to an IInt8Calibrator* to easily be assigned to the ptq_calibrator field in ExtraInfo
262+
* Convenience function to convert to an IInt8Calibrator* to easily be assigned to the ptq_calibrator field in CompileSpec
263263
*
264264
* @return nvinfer1::IInt8Calibrator*
265265
*/

Diff for: cpp/api/include/trtorch/trtorch.h

+8-8
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace trtorch {
3939
* Settings data structure for TRTorch compilation
4040
*
4141
*/
42-
struct TRTORCH_API ExtraInfo {
42+
struct TRTORCH_API CompileSpec {
4343
/**
4444
* @brief A struct to hold an input range (used by TensorRT Optimization profile)
4545
*
@@ -256,7 +256,7 @@ struct TRTORCH_API ExtraInfo {
256256
*
257257
* @param input_ranges
258258
*/
259-
ExtraInfo(std::vector<InputRange> input_ranges)
259+
CompileSpec(std::vector<InputRange> input_ranges)
260260
: input_ranges(std::move(input_ranges)) {}
261261
/**
262262
* @brief Construct a new Extra Info object
@@ -265,14 +265,14 @@ struct TRTORCH_API ExtraInfo {
265265
*
266266
* @param fixed_sizes
267267
*/
268-
ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
268+
CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes);
269269
/**
270270
* @brief Construct a new Extra Info object
271271
* Convenience constructor to set fixed input size from c10::ArrayRef's (the output of tensor.sizes()) describing size of input tensors.
272272
* Each entry in the vector represents an input and should be provided in call order.
273273
* @param fixed_sizes
274274
*/
275-
ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
275+
CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
276276

277277
// Defaults should reflect TensorRT defaults for BuilderConfig
278278

@@ -379,7 +379,7 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, st
379379
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
380380
*
381381
* @param module: torch::jit::Module - Existing TorchScript module
382-
* @param info: trtorch::ExtraInfo - Compilation settings
382+
* @param info: trtorch::CompileSpec - Compilation settings
383383
*
384384
* Takes an existing TorchScript module and a set of settings to configure the compiler
385385
* and will convert methods to JIT Graphs which call equivalent TensorRT engines
@@ -388,20 +388,20 @@ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::Module& module, st
388388
*
389389
* @return: A new module targeting a TensorRT engine
390390
*/
391-
TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, ExtraInfo info);
391+
TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, CompileSpec info);
392392

393393
/**
394394
* @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
395395
*
396396
* @param module: torch::jit::Module - Existing TorchScript module
397397
* @param method_name: std::string - Name of method to compile
398-
* @param info: trtorch::ExtraInfo - Compilation settings
398+
* @param info: trtorch::CompileSpec - Compilation settings
399399
*
400400
* Takes a existing TorchScript module and a set of settings to configure the compiler
401401
* and will convert selected method to a serialized TensorRT engine which can be run with
402402
* TensorRT
403403
*
404404
* @return: std::string: Serialized TensorRT engine equivalent to the method graph
405405
*/
406-
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, ExtraInfo info);
406+
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, CompileSpec info);
407407
} // namespace trtorch

Diff for: cpp/api/src/extra_info.cpp renamed to cpp/api/src/compile_spec.cpp

+20-20
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "trtorch/trtorch.h"
77

88
namespace trtorch {
9-
ExtraInfo::DataType::DataType(c10::ScalarType t) {
9+
CompileSpec::DataType::DataType(c10::ScalarType t) {
1010
TRTORCH_CHECK(t == at::kHalf || t == at::kFloat || t == at::kChar, "Data type is unsupported");
1111
switch (t) {
1212
case at::kHalf:
@@ -21,70 +21,70 @@ ExtraInfo::DataType::DataType(c10::ScalarType t) {
2121
}
2222
}
2323

24-
ExtraInfo::DeviceType::DeviceType(c10::DeviceType t) {
24+
CompileSpec::DeviceType::DeviceType(c10::DeviceType t) {
2525
TRTORCH_CHECK(t == at::kCUDA, "Device type when specified using torch device enum must be torch::kCUDA");
2626
value = DeviceType::kGPU;
2727
}
2828

29-
ExtraInfo::InputRange::InputRange(std::vector<int64_t> opt) {
29+
CompileSpec::InputRange::InputRange(std::vector<int64_t> opt) {
3030
this->opt = opt;
3131
this->min = opt;
3232
this->max = opt;
3333
}
3434

35-
ExtraInfo::InputRange::InputRange(c10::IntArrayRef opt) {
35+
CompileSpec::InputRange::InputRange(c10::IntArrayRef opt) {
3636
this->opt = core::util::toVec(opt);
3737
this->min = core::util::toVec(opt);
3838
this->max = core::util::toVec(opt);
3939
}
4040

41-
ExtraInfo::InputRange::InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max) {
41+
CompileSpec::InputRange::InputRange(std::vector<int64_t> min, std::vector<int64_t> opt, std::vector<int64_t> max) {
4242
this->opt = opt;
4343
this->min = min;
4444
this->max = max;
4545
}
4646

47-
ExtraInfo::InputRange::InputRange(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) {
47+
CompileSpec::InputRange::InputRange(c10::IntArrayRef min, c10::IntArrayRef opt, c10::IntArrayRef max) {
4848
this->opt = core::util::toVec(opt);
4949
this->min = core::util::toVec(min);
5050
this->max = core::util::toVec(max);
5151
}
5252

53-
ExtraInfo::ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes) {
53+
CompileSpec::CompileSpec(std::vector<c10::ArrayRef<int64_t>> fixed_sizes) {
5454
for (auto in : fixed_sizes) {
5555
input_ranges.push_back(InputRange(in));
5656
}
5757
}
5858

59-
ExtraInfo::ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes) {
59+
CompileSpec::CompileSpec(std::vector<std::vector<int64_t>> fixed_sizes) {
6060
for (auto in : fixed_sizes) {
6161
input_ranges.push_back(InputRange(in));
6262
}
6363
}
6464

65-
core::conversion::InputRange to_internal_input_range(ExtraInfo::InputRange i) {
65+
core::conversion::InputRange to_internal_input_range(CompileSpec::InputRange i) {
6666
return core::conversion::InputRange(i.min, i.opt, i.max);
6767
}
6868

69-
std::vector<core::conversion::InputRange> to_vec_internal_input_ranges(std::vector<ExtraInfo::InputRange> external) {
69+
std::vector<core::conversion::InputRange> to_vec_internal_input_ranges(std::vector<CompileSpec::InputRange> external) {
7070
std::vector<core::conversion::InputRange> internal;
7171
for (auto range : external) {
7272
internal.push_back(to_internal_input_range(range));
7373
}
7474
return internal;
7575
}
7676

77-
core::ExtraInfo to_internal_extra_info(ExtraInfo external) {
78-
core::ExtraInfo internal(to_vec_internal_input_ranges(external.input_ranges));
77+
core::CompileSpec to_internal_compile_spec(CompileSpec external) {
78+
core::CompileSpec internal(to_vec_internal_input_ranges(external.input_ranges));
7979

8080
switch(external.op_precision) {
81-
case ExtraInfo::DataType::kChar:
81+
case CompileSpec::DataType::kChar:
8282
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kINT8;
8383
break;
84-
case ExtraInfo::DataType::kHalf:
84+
case CompileSpec::DataType::kHalf:
8585
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kHALF;
8686
break;
87-
case ExtraInfo::DataType::kFloat:
87+
case CompileSpec::DataType::kFloat:
8888
default:
8989
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT;
9090
}
@@ -96,22 +96,22 @@ core::ExtraInfo to_internal_extra_info(ExtraInfo external) {
9696
internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
9797

9898
switch(external.device) {
99-
case ExtraInfo::DeviceType::kDLA:
99+
case CompileSpec::DeviceType::kDLA:
100100
internal.convert_info.engine_settings.device = nvinfer1::DeviceType::kDLA;
101101
break;
102-
case ExtraInfo::DeviceType::kGPU:
102+
case CompileSpec::DeviceType::kGPU:
103103
default:
104104
internal.convert_info.engine_settings.device = nvinfer1::DeviceType::kGPU;
105105
}
106106

107107
switch(external.capability) {
108-
case ExtraInfo::EngineCapability::kSAFE_GPU:
108+
case CompileSpec::EngineCapability::kSAFE_GPU:
109109
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_GPU;
110110
break;
111-
case ExtraInfo::EngineCapability::kSAFE_DLA:
111+
case CompileSpec::EngineCapability::kSAFE_DLA:
112112
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_DLA;
113113
break;
114-
case ExtraInfo::EngineCapability::kDEFAULT:
114+
case CompileSpec::EngineCapability::kDEFAULT:
115115
default:
116116
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kDEFAULT;
117117

0 commit comments

Comments
 (0)