@@ -11,7 +11,7 @@ namespace trtorch {
11
11
namespace core {
12
12
namespace runtime {
13
13
14
- typedef enum { ABI_TARGET_IDX = 0 , DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
14
+ typedef enum { ABI_TARGET_IDX = 0 , NAME_IDX, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
15
15
16
16
std::string slugify (std::string s) {
17
17
std::replace (s.begin (), s.end (), ' .' , ' _' );
@@ -37,8 +37,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
37
37
TRTORCH_CHECK (
38
38
serialized_info[ABI_TARGET_IDX] == ABI_VERSION,
39
39
" Program to be deserialized targets a different TRTorch ABI Version ("
40
- << serialized_info[ABI_TARGET_IDX] << " ) than the TRTorch Runtime ABI (" << ABI_VERSION << " )" );
41
- std::string _name = " deserialized_trt " ;
40
+ << serialized_info[ABI_TARGET_IDX] << " ) than the TRTorch Runtime ABI Version (" << ABI_VERSION << " )" );
41
+ std::string _name = serialized_info[NAME_IDX] ;
42
42
std::string engine_info = serialized_info[ENGINE_IDX];
43
43
44
44
CudaDevice cuda_device = deserialize_device (serialized_info[DEVICE_IDX]);
@@ -55,7 +55,7 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
55
55
56
56
rt = nvinfer1::createInferRuntime (logger);
57
57
58
- name = slugify (mod_name) + " _engine " ;
58
+ name = slugify (mod_name);
59
59
60
60
cuda_engine = rt->deserializeCudaEngine (serialized_engine.c_str (), serialized_engine.size ());
61
61
TRTORCH_CHECK ((cuda_engine != nullptr ), " Unable to deserialize the TensorRT engine" );
@@ -70,8 +70,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
70
70
uint64_t outputs = 0 ;
71
71
72
72
for (int64_t x = 0 ; x < cuda_engine->getNbBindings (); x++) {
73
- std::string name = cuda_engine->getBindingName (x);
74
- std::string idx_s = name .substr (name .find (" _" ) + 1 );
73
+ std::string bind_name = cuda_engine->getBindingName (x);
74
+ std::string idx_s = bind_name .substr (bind_name .find (" _" ) + 1 );
75
75
uint64_t idx = static_cast <uint64_t >(std::stoi (idx_s));
76
76
77
77
if (cuda_engine->bindingIsInput (x)) {
@@ -124,9 +124,12 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
124
124
auto trt_engine = std::string ((const char *)serialized_trt_engine->data (), serialized_trt_engine->size ());
125
125
126
126
std::vector<std::string> serialize_info;
127
- serialize_info.push_back (ABI_VERSION);
128
- serialize_info.push_back (serialize_device (self->device_info ));
129
- serialize_info.push_back (trt_engine);
127
+ serialize_info.resize (ENGINE_IDX + 1 );
128
+
129
+ serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
130
+ serialize_info[NAME_IDX] = self->name ;
131
+ serialize_info[DEVICE_IDX] = serialize_device (self->device_info );
132
+ serialize_info[ENGINE_IDX] = trt_engine;
130
133
return serialize_info;
131
134
},
132
135
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
0 commit comments