@@ -11,6 +11,8 @@ namespace trtorch {
11
11
namespace core {
12
12
namespace runtime {
13
13
14
+ typedef enum { ABI_TARGET_IDX = 0 , DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
15
+
14
16
std::string slugify (std::string s) {
15
17
std::replace (s.begin (), s.end (), ' .' , ' _' );
16
18
return s;
@@ -30,6 +32,12 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
30
32
std::string (" [] = " ),
31
33
util::logging::get_logger().get_reportable_severity(),
32
34
util::logging::get_logger().get_is_colored_output_on()) {
35
+ TRTORCH_CHECK (
36
+ serialized_info.size () == ENGINE_IDX + 1 , " Program to be deserialized targets an incompatible TRTorch ABI" );
37
+ TRTORCH_CHECK (
38
+ serialized_info[ABI_TARGET_IDX] == ABI_VERSION,
39
+ " Program to be deserialized targets a different TRTorch ABI Version ("
40
+ << serialized_info[ABI_TARGET_IDX] << " ) than the TRTorch Runtime ABI (" << ABI_VERSION << " )" );
33
41
std::string _name = " deserialized_trt" ;
34
42
std::string engine_info = serialized_info[ENGINE_IDX];
35
43
@@ -116,6 +124,7 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
116
124
auto trt_engine = std::string ((const char *)serialized_trt_engine->data (), serialized_trt_engine->size ());
117
125
118
126
std::vector<std::string> serialize_info;
127
+ serialize_info.push_back (ABI_VERSION);
119
128
serialize_info.push_back (serialize_device (self->device_info ));
120
129
serialize_info.push_back (trt_engine);
121
130
return serialize_info;
@@ -124,123 +133,6 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
124
133
return c10::make_intrusive<TRTEngine>(std::move (seralized_info));
125
134
});
126
135
} // namespace
127
- void set_cuda_device (CudaDevice& cuda_device) {
128
- TRTORCH_CHECK ((cudaSetDevice (cuda_device.id ) == cudaSuccess), " Unable to set device: " << cuda_device.id );
129
- }
130
-
131
- void get_cuda_device (CudaDevice& cuda_device) {
132
- int device = 0 ;
133
- TRTORCH_CHECK (
134
- (cudaGetDevice (reinterpret_cast <int *>(&device)) == cudaSuccess),
135
- " Unable to get current device: " << cuda_device.id );
136
- cuda_device.id = static_cast <int64_t >(device);
137
- cudaDeviceProp device_prop;
138
- TRTORCH_CHECK (
139
- (cudaGetDeviceProperties (&device_prop, cuda_device.id ) == cudaSuccess),
140
- " Unable to get CUDA properties from device:" << cuda_device.id );
141
- cuda_device.set_major (device_prop.major );
142
- cuda_device.set_minor (device_prop.minor );
143
- std::string device_name (device_prop.name );
144
- cuda_device.set_device_name (device_name);
145
- }
146
-
147
- std::string serialize_device (CudaDevice& cuda_device) {
148
- void * buffer = new char [sizeof (cuda_device)];
149
- void * ref_buf = buffer;
150
-
151
- int64_t temp = cuda_device.get_id ();
152
- memcpy (buffer, reinterpret_cast <int64_t *>(&temp), sizeof (int64_t ));
153
- buffer = static_cast <char *>(buffer) + sizeof (int64_t );
154
-
155
- temp = cuda_device.get_major ();
156
- memcpy (buffer, reinterpret_cast <int64_t *>(&temp), sizeof (int64_t ));
157
- buffer = static_cast <char *>(buffer) + sizeof (int64_t );
158
-
159
- temp = cuda_device.get_minor ();
160
- memcpy (buffer, reinterpret_cast <int64_t *>(&temp), sizeof (int64_t ));
161
- buffer = static_cast <char *>(buffer) + sizeof (int64_t );
162
-
163
- auto device_type = cuda_device.get_device_type ();
164
- memcpy (buffer, reinterpret_cast <char *>(&device_type), sizeof (nvinfer1::DeviceType));
165
- buffer = static_cast <char *>(buffer) + sizeof (nvinfer1::DeviceType);
166
-
167
- size_t device_name_len = cuda_device.get_device_name_len ();
168
- memcpy (buffer, reinterpret_cast <char *>(&device_name_len), sizeof (size_t ));
169
- buffer = static_cast <char *>(buffer) + sizeof (size_t );
170
-
171
- auto device_name = cuda_device.get_device_name ();
172
- memcpy (buffer, reinterpret_cast <char *>(&device_name), device_name.size ());
173
- buffer = static_cast <char *>(buffer) + device_name.size ();
174
-
175
- return std::string ((const char *)ref_buf, sizeof (int64_t ) * 3 + sizeof (nvinfer1::DeviceType) + device_name.size ());
176
- }
177
-
178
- CudaDevice deserialize_device (std::string device_info) {
179
- CudaDevice ret;
180
- char * buffer = new char [device_info.size () + 1 ];
181
- std::copy (device_info.begin (), device_info.end (), buffer);
182
- int64_t temp = 0 ;
183
-
184
- memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int64_t ));
185
- buffer += sizeof (int64_t );
186
- ret.set_id (temp);
187
-
188
- memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int64_t ));
189
- buffer += sizeof (int64_t );
190
- ret.set_major (temp);
191
-
192
- memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int64_t ));
193
- buffer += sizeof (int64_t );
194
- ret.set_minor (temp);
195
-
196
- nvinfer1::DeviceType device_type;
197
- memcpy (&device_type, reinterpret_cast <char *>(buffer), sizeof (nvinfer1::DeviceType));
198
- buffer += sizeof (nvinfer1::DeviceType);
199
-
200
- size_t size;
201
- memcpy (&size, reinterpret_cast <size_t *>(&buffer), sizeof (size_t ));
202
- buffer += sizeof (size_t );
203
-
204
- ret.set_device_name_len (size);
205
-
206
- std::string device_name;
207
- memcpy (&device_name, reinterpret_cast <char *>(buffer), size * sizeof (char ));
208
- buffer += size * sizeof (char );
209
-
210
- ret.set_device_name (device_name);
211
-
212
- return ret;
213
- }
214
-
215
- CudaDevice get_device_info (int64_t gpu_id, nvinfer1::DeviceType device_type) {
216
- CudaDevice cuda_device;
217
- cudaDeviceProp device_prop;
218
-
219
- // Device ID
220
- cuda_device.set_id (gpu_id);
221
-
222
- // Get Device Properties
223
- cudaGetDeviceProperties (&device_prop, gpu_id);
224
-
225
- // Compute capability major version
226
- cuda_device.set_major (device_prop.major );
227
-
228
- // Compute capability minor version
229
- cuda_device.set_minor (device_prop.minor );
230
-
231
- std::string device_name (device_prop.name );
232
-
233
- // Set Device name
234
- cuda_device.set_device_name (device_name);
235
-
236
- // Set Device name len for serialization/deserialization
237
- cuda_device.set_device_name_len (device_name.size ());
238
-
239
- // Set Device Type
240
- cuda_device.set_device_type (device_type);
241
-
242
- return cuda_device;
243
- }
244
136
245
137
} // namespace runtime
246
138
} // namespace core
0 commit comments