Skip to content

Commit 6eb3bb2

Browse files
committed
feat(//core/runtime)!: Better and more portable names for engines
BREAKING CHANGE: This bumps the TRTorch ABI version to 3 due to a new field for engine name included in the serialized form of TRTEngine. This lets deserialized engines have the same name they serialized with Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent c54ed13 commit 6eb3bb2

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

Diff for: core/compiler.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void AddEngineToGraph(
3636
std::string engine_id = "",
3737
bool fallback = false) {
3838
auto engine_ptr =
39-
c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine, device_info);
39+
c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + "_engine_" + engine_id, serialized_engine, device_info);
4040
// Get required metadata about the engine out
4141
auto num_io = engine_ptr->num_io;
4242
auto name = engine_ptr->name;

Diff for: core/runtime/TRTEngine.cpp

+12-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace trtorch {
1111
namespace core {
1212
namespace runtime {
1313

14-
typedef enum { ABI_TARGET_IDX = 0, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
14+
typedef enum { ABI_TARGET_IDX = 0, NAME_IDX, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
1515

1616
std::string slugify(std::string s) {
1717
std::replace(s.begin(), s.end(), '.', '_');
@@ -37,8 +37,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
3737
TRTORCH_CHECK(
3838
serialized_info[ABI_TARGET_IDX] == ABI_VERSION,
3939
"Program to be deserialized targets a different TRTorch ABI Version ("
40-
<< serialized_info[ABI_TARGET_IDX] << ") than the TRTorch Runtime ABI (" << ABI_VERSION << ")");
41-
std::string _name = "deserialized_trt";
40+
<< serialized_info[ABI_TARGET_IDX] << ") than the TRTorch Runtime ABI Version (" << ABI_VERSION << ")");
41+
std::string _name = serialized_info[NAME_IDX];
4242
std::string engine_info = serialized_info[ENGINE_IDX];
4343

4444
CudaDevice cuda_device = deserialize_device(serialized_info[DEVICE_IDX]);
@@ -55,7 +55,7 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
5555

5656
rt = nvinfer1::createInferRuntime(logger);
5757

58-
name = slugify(mod_name) + "_engine";
58+
name = slugify(mod_name);
5959

6060
cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
6161
TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine");
@@ -70,8 +70,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
7070
uint64_t outputs = 0;
7171

7272
for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
73-
std::string name = cuda_engine->getBindingName(x);
74-
std::string idx_s = name.substr(name.find("_") + 1);
73+
std::string bind_name = cuda_engine->getBindingName(x);
74+
std::string idx_s = bind_name.substr(bind_name.find("_") + 1);
7575
uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));
7676

7777
if (cuda_engine->bindingIsInput(x)) {
@@ -124,9 +124,12 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
124124
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
125125

126126
std::vector<std::string> serialize_info;
127-
serialize_info.push_back(ABI_VERSION);
128-
serialize_info.push_back(serialize_device(self->device_info));
129-
serialize_info.push_back(trt_engine);
127+
serialize_info.resize(ENGINE_IDX + 1);
128+
129+
serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
130+
serialize_info[NAME_IDX] = self->name;
131+
serialize_info[DEVICE_IDX] = serialize_device(self->device_info);
132+
serialize_info[ENGINE_IDX] = trt_engine;
130133
return serialize_info;
131134
},
132135
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {

Diff for: core/runtime/runtime.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace core {
1111
namespace runtime {
1212

1313
using EngineID = int64_t;
14-
const std::string ABI_VERSION = "2";
14+
const std::string ABI_VERSION = "3";
1515

1616
struct CudaDevice {
1717
int64_t id; // CUDA device id

0 commit comments

Comments
 (0)