Skip to content

Commit 9327cce

Browse files
committed
feat(serde)!: Refactor CudaDevice struct, implement ABI versioning,
serde cleanup BREAKING CHANGE: This commit cleans up the WIP CudaDevice class, simplifying implementation and formalizing the seralized format for CUDA devices. It also implements ABI Versioning. The first entry in the serialized format of a TRTEngine now records the ABI that the engine was compiled with, defining expected compatibility with the TRTorch runtime. If the ABI version does not match, the runtime will error out asking to recompile the program. ABI version is a monotonically increasing integer and should be incremented everytime the serialization format changes in some way. This commit cleans up the CudaDevice class, implementing a number of constructors to replace the various utility functions that populate the struct. Descriptive utility functions remain but solely call the relevant constructor. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 611f6a1 commit 9327cce

File tree

11 files changed

+270
-237
lines changed

11 files changed

+270
-237
lines changed

Diff for: core/compiler.cpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
223223
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
224224
auto temp_g = std::make_shared<torch::jit::Graph>();
225225
auto device_spec = convert_cfg.engine_settings.device;
226-
auto cuda_device = runtime::get_device_info(device_spec.gpu_id, device_spec.device_type);
226+
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
227227
AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
228228

229229
seg_block.update_graph(temp_g);
@@ -265,7 +265,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
265265
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
266266
auto new_g = std::make_shared<torch::jit::Graph>();
267267
auto device_spec = cfg.convert_info.engine_settings.device;
268-
auto cuda_device = runtime::get_device_info(device_spec.gpu_id, device_spec.device_type);
268+
auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
269269
AddEngineToGraph(new_mod, new_g, engine, cuda_device);
270270
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
271271
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
@@ -277,9 +277,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
277277
return new_mod;
278278
}
279279

280-
torch::jit::script::Module EmbedEngineInNewModule(
281-
const std::string& engine,
282-
trtorch::core::runtime::CudaDevice cuda_device) {
280+
torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device) {
283281
std::ostringstream engine_id;
284282
engine_id << reinterpret_cast<const int*>(&engine);
285283
torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());

Diff for: core/runtime/BUILD

+3
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ config_setting(
1010
cc_library(
1111
name = "runtime",
1212
srcs = [
13+
"CudaDevice.cpp",
14+
"DeviceList.cpp",
1315
"TRTEngine.cpp",
1416
"register_trt_op.cpp",
17+
"runtime.cpp"
1518
],
1619
hdrs = [
1720
"runtime.h",

Diff for: core/runtime/CudaDevice.cpp

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include "cuda_runtime.h"
2+
3+
#include "core/runtime/runtime.h"
4+
#include "core/util/prelude.h"
5+
6+
namespace trtorch {
7+
namespace core {
8+
namespace runtime {
9+
10+
const std::string DEVICE_INFO_DELIM = "%";
11+
12+
typedef enum { ID_IDX = 0, SM_MAJOR_IDX, SM_MINOR_IDX, DEVICE_TYPE_IDX, DEVICE_NAME_IDX } SerializedDeviceInfoIndex;
13+
14+
CudaDevice::CudaDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}
15+
16+
CudaDevice::CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
17+
CudaDevice cuda_device;
18+
cudaDeviceProp device_prop;
19+
20+
// Device ID
21+
this->id = gpu_id;
22+
23+
// Get Device Properties
24+
cudaGetDeviceProperties(&device_prop, gpu_id);
25+
26+
// Compute capability major version
27+
this->major = device_prop.major;
28+
29+
// Compute capability minor version
30+
this->minor = device_prop.minor;
31+
32+
std::string device_name(device_prop.name);
33+
34+
// Set Device name
35+
this->device_name = device_name;
36+
37+
// Set Device name len for serialization/deserialization
38+
this->device_name_len = device_name.size();
39+
40+
// Set Device Type
41+
this->device_type = device_type;
42+
}
43+
44+
// NOTE: Serialization Format for Device Info:
45+
// id%major%minor%(enum)device_type%device_name
46+
47+
CudaDevice::CudaDevice(std::string device_info) {
48+
LOG_DEBUG("Deserializing Device Info: " << device_info);
49+
50+
std::vector<std::string> tokens;
51+
int64_t start = 0;
52+
int64_t end = device_info.find(DEVICE_INFO_DELIM);
53+
54+
while (end != -1) {
55+
tokens.push_back(device_info.substr(start, end - start));
56+
start = end + DEVICE_INFO_DELIM.size();
57+
end = device_info.find(DEVICE_INFO_DELIM, start);
58+
}
59+
tokens.push_back(device_info.substr(start, end - start));
60+
61+
TRTORCH_CHECK(tokens.size() == DEVICE_NAME_IDX + 1, "Unable to deserializable program target device infomation");
62+
63+
id = std::stoi(tokens[ID_IDX]);
64+
major = std::stoi(tokens[SM_MAJOR_IDX]);
65+
minor = std::stoi(tokens[SM_MINOR_IDX]);
66+
device_type = (nvinfer1::DeviceType)(std::stoi(tokens[DEVICE_TYPE_IDX]));
67+
device_name = tokens[DEVICE_NAME_IDX];
68+
69+
LOG_DEBUG("Deserialized Device Info: " << *this);
70+
}
71+
72+
std::string CudaDevice::serialize() {
73+
std::stringstream ss;
74+
// clang-format off
75+
ss << id << DEVICE_INFO_DELIM \
76+
<< major << DEVICE_INFO_DELIM \
77+
<< minor << DEVICE_INFO_DELIM \
78+
<< (int64_t) device_type << DEVICE_INFO_DELIM
79+
<< device_name;
80+
// clang-format on
81+
82+
std::string serialized_device_info = ss.str();
83+
84+
LOG_DEBUG("Serialized Device Info: " << serialized_device_info);
85+
86+
return serialized_device_info;
87+
}
88+
89+
std::string CudaDevice::getSMCapability() const {
90+
std::stringstream ss;
91+
ss << major << "." << minor;
92+
return ss.str();
93+
}
94+
95+
std::ostream& operator<<(std::ostream& os, const CudaDevice& device) {
96+
os << "Device(ID: " << device.id << ", Name: " << device.device_name << ", SM Capability: " << device.major << '.'
97+
<< device.minor << ", Type: " << device.device_type << ')';
98+
return os;
99+
}
100+
101+
} // namespace runtime
102+
} // namespace core
103+
} // namespace trtorch

Diff for: core/runtime/DeviceList.cpp

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "cuda_runtime.h"
2+
3+
#include "core/runtime/runtime.h"
4+
#include "core/util/prelude.h"
5+
6+
namespace trtorch {
7+
namespace core {
8+
namespace runtime {
9+
10+
DeviceList::DeviceList() {
11+
int num_devices = 0;
12+
auto status = cudaGetDeviceCount(&num_devices);
13+
TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
14+
for (int i = 0; i < num_devices; i++) {
15+
device_list[i] = CudaDevice(i, nvinfer1::DeviceType::kGPU);
16+
}
17+
18+
// REVIEW: DO WE CARE ABOUT DLA?
19+
20+
LOG_DEBUG("Runtime:\n Available CUDA Devices: \n" << this->dump_list());
21+
}
22+
23+
void DeviceList::insert(int device_id, CudaDevice cuda_device) {
24+
device_list[device_id] = cuda_device;
25+
}
26+
27+
CudaDevice DeviceList::find(int device_id) {
28+
return device_list[device_id];
29+
}
30+
31+
DeviceList::DeviceMap DeviceList::get_devices() {
32+
return device_list;
33+
}
34+
35+
std::string DeviceList::dump_list() {
36+
std::stringstream ss;
37+
for (auto it = device_list.begin(); it != device_list.end(); ++it) {
38+
ss << " " << it->second << std::endl;
39+
}
40+
return ss.str();
41+
}
42+
43+
} // namespace runtime
44+
} // namespace core
45+
} // namespace trtorch

Diff for: core/runtime/TRTEngine.cpp

+9-117
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ namespace trtorch {
1111
namespace core {
1212
namespace runtime {
1313

14+
typedef enum { ABI_TARGET_IDX = 0, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
15+
1416
std::string slugify(std::string s) {
1517
std::replace(s.begin(), s.end(), '.', '_');
1618
return s;
@@ -30,6 +32,12 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
3032
std::string("[] = "),
3133
util::logging::get_logger().get_reportable_severity(),
3234
util::logging::get_logger().get_is_colored_output_on()) {
35+
TRTORCH_CHECK(
36+
serialized_info.size() == ENGINE_IDX + 1, "Program to be deserialized targets an incompatible TRTorch ABI");
37+
TRTORCH_CHECK(
38+
serialized_info[ABI_TARGET_IDX] == ABI_VERSION,
39+
"Program to be deserialized targets a different TRTorch ABI Version ("
40+
<< serialized_info[ABI_TARGET_IDX] << ") than the TRTorch Runtime ABI (" << ABI_VERSION << ")");
3341
std::string _name = "deserialized_trt";
3442
std::string engine_info = serialized_info[ENGINE_IDX];
3543

@@ -116,6 +124,7 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
116124
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
117125

118126
std::vector<std::string> serialize_info;
127+
serialize_info.push_back(ABI_VERSION);
119128
serialize_info.push_back(serialize_device(self->device_info));
120129
serialize_info.push_back(trt_engine);
121130
return serialize_info;
@@ -124,123 +133,6 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
124133
return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
125134
});
126135
} // namespace
127-
void set_cuda_device(CudaDevice& cuda_device) {
128-
TRTORCH_CHECK((cudaSetDevice(cuda_device.id) == cudaSuccess), "Unable to set device: " << cuda_device.id);
129-
}
130-
131-
void get_cuda_device(CudaDevice& cuda_device) {
132-
int device = 0;
133-
TRTORCH_CHECK(
134-
(cudaGetDevice(reinterpret_cast<int*>(&device)) == cudaSuccess),
135-
"Unable to get current device: " << cuda_device.id);
136-
cuda_device.id = static_cast<int64_t>(device);
137-
cudaDeviceProp device_prop;
138-
TRTORCH_CHECK(
139-
(cudaGetDeviceProperties(&device_prop, cuda_device.id) == cudaSuccess),
140-
"Unable to get CUDA properties from device:" << cuda_device.id);
141-
cuda_device.set_major(device_prop.major);
142-
cuda_device.set_minor(device_prop.minor);
143-
std::string device_name(device_prop.name);
144-
cuda_device.set_device_name(device_name);
145-
}
146-
147-
std::string serialize_device(CudaDevice& cuda_device) {
148-
void* buffer = new char[sizeof(cuda_device)];
149-
void* ref_buf = buffer;
150-
151-
int64_t temp = cuda_device.get_id();
152-
memcpy(buffer, reinterpret_cast<int64_t*>(&temp), sizeof(int64_t));
153-
buffer = static_cast<char*>(buffer) + sizeof(int64_t);
154-
155-
temp = cuda_device.get_major();
156-
memcpy(buffer, reinterpret_cast<int64_t*>(&temp), sizeof(int64_t));
157-
buffer = static_cast<char*>(buffer) + sizeof(int64_t);
158-
159-
temp = cuda_device.get_minor();
160-
memcpy(buffer, reinterpret_cast<int64_t*>(&temp), sizeof(int64_t));
161-
buffer = static_cast<char*>(buffer) + sizeof(int64_t);
162-
163-
auto device_type = cuda_device.get_device_type();
164-
memcpy(buffer, reinterpret_cast<char*>(&device_type), sizeof(nvinfer1::DeviceType));
165-
buffer = static_cast<char*>(buffer) + sizeof(nvinfer1::DeviceType);
166-
167-
size_t device_name_len = cuda_device.get_device_name_len();
168-
memcpy(buffer, reinterpret_cast<char*>(&device_name_len), sizeof(size_t));
169-
buffer = static_cast<char*>(buffer) + sizeof(size_t);
170-
171-
auto device_name = cuda_device.get_device_name();
172-
memcpy(buffer, reinterpret_cast<char*>(&device_name), device_name.size());
173-
buffer = static_cast<char*>(buffer) + device_name.size();
174-
175-
return std::string((const char*)ref_buf, sizeof(int64_t) * 3 + sizeof(nvinfer1::DeviceType) + device_name.size());
176-
}
177-
178-
CudaDevice deserialize_device(std::string device_info) {
179-
CudaDevice ret;
180-
char* buffer = new char[device_info.size() + 1];
181-
std::copy(device_info.begin(), device_info.end(), buffer);
182-
int64_t temp = 0;
183-
184-
memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int64_t));
185-
buffer += sizeof(int64_t);
186-
ret.set_id(temp);
187-
188-
memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int64_t));
189-
buffer += sizeof(int64_t);
190-
ret.set_major(temp);
191-
192-
memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int64_t));
193-
buffer += sizeof(int64_t);
194-
ret.set_minor(temp);
195-
196-
nvinfer1::DeviceType device_type;
197-
memcpy(&device_type, reinterpret_cast<char*>(buffer), sizeof(nvinfer1::DeviceType));
198-
buffer += sizeof(nvinfer1::DeviceType);
199-
200-
size_t size;
201-
memcpy(&size, reinterpret_cast<size_t*>(&buffer), sizeof(size_t));
202-
buffer += sizeof(size_t);
203-
204-
ret.set_device_name_len(size);
205-
206-
std::string device_name;
207-
memcpy(&device_name, reinterpret_cast<char*>(buffer), size * sizeof(char));
208-
buffer += size * sizeof(char);
209-
210-
ret.set_device_name(device_name);
211-
212-
return ret;
213-
}
214-
215-
CudaDevice get_device_info(int64_t gpu_id, nvinfer1::DeviceType device_type) {
216-
CudaDevice cuda_device;
217-
cudaDeviceProp device_prop;
218-
219-
// Device ID
220-
cuda_device.set_id(gpu_id);
221-
222-
// Get Device Properties
223-
cudaGetDeviceProperties(&device_prop, gpu_id);
224-
225-
// Compute capability major version
226-
cuda_device.set_major(device_prop.major);
227-
228-
// Compute capability minor version
229-
cuda_device.set_minor(device_prop.minor);
230-
231-
std::string device_name(device_prop.name);
232-
233-
// Set Device name
234-
cuda_device.set_device_name(device_name);
235-
236-
// Set Device name len for serialization/deserialization
237-
cuda_device.set_device_name_len(device_name.size());
238-
239-
// Set Device Type
240-
cuda_device.set_device_type(device_type);
241-
242-
return cuda_device;
243-
}
244136

245137
} // namespace runtime
246138
} // namespace core

0 commit comments

Comments
 (0)