Skip to content

Commit 7e1a1ca

Browse files
committed
support C++ runtime and add tests
1 parent 107599a commit 7e1a1ca

File tree

9 files changed

+585
-389
lines changed

9 files changed

+585
-389
lines changed

core/runtime/TRTEngine.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,37 @@ std::vector<std::string> split(const std::string& str, char delim) {
3030
return strings;
3131
}
3232

33+
DynamicOutputAllocator::DynamicOutputAllocator(const std::unordered_map<std::string, at::ScalarType>& output_dtypes)
34+
: dtypes(output_dtypes) {}
35+
36+
void* DynamicOutputAllocator::reallocateOutputAsync(
37+
char const* tensorName,
38+
void* currentMemory,
39+
uint64_t size,
40+
uint64_t alignment,
41+
cudaStream_t stream) {
42+
std::vector<int64_t> shape = {static_cast<int64_t>(size)};
43+
auto it = buffers.find(tensorName);
44+
if (it == buffers.end() || it->second.sizes() != shape) {
45+
buffers[tensorName] = at::empty(shape, at::TensorOptions().dtype(dtypes.at(tensorName)).device(c10::kCUDA));
46+
return buffers[tensorName].data_ptr();
47+
} else {
48+
return it->second.data_ptr();
49+
}
50+
}
51+
52+
void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept {
53+
shapes[tensorName] = dims;
54+
}
55+
3356
TRTEngine::TRTEngine(
3457
const std::string& serialized_engine,
3558
const RTDevice& cuda_device,
3659
const std::vector<std::string>& _in_binding_names,
3760
const std::vector<std::string>& _out_binding_names,
3861
const Platform& target_platform,
3962
bool hardware_compatible,
63+
bool requires_output_allocator,
4064
const std::string& serialized_metadata)
4165
: TRTEngine(
4266
"deserialized_trt",
@@ -46,6 +70,7 @@ TRTEngine::TRTEngine(
4670
_out_binding_names,
4771
target_platform,
4872
hardware_compatible,
73+
requires_output_allocator,
4974
serialized_metadata) {}
5075

5176
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
@@ -57,6 +82,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
5782
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
5883
Platform(serialized_info[TARGET_PLATFORM_IDX]),
5984
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
85+
static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
6086
serialized_info[SERIALIZED_METADATA_IDX]) {}
6187

6288
TRTEngine::TRTEngine(
@@ -67,6 +93,7 @@ TRTEngine::TRTEngine(
6793
const std::vector<std::string>& _out_binding_names,
6894
const Platform& target_platform,
6995
bool hardware_compatible,
96+
bool requires_output_allocator,
7097
const std::string& serialized_metadata) {
7198
TORCHTRT_CHECK(
7299
is_supported_on_current_platform(target_platform),
@@ -79,6 +106,7 @@ TRTEngine::TRTEngine(
79106
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
80107

81108
this->serialized_metadata = serialized_metadata;
109+
this->requires_output_allocator = requires_output_allocator;
82110
device_info = most_compatible_device.value();
83111
multi_gpu_device_check();
84112
set_rt_device(device_info);
@@ -397,6 +425,7 @@ FlattenedState TRTEngine::__obj_flatten__() {
397425
std::tuple("out_binding_names", serialized_info[OUTPUT_BINDING_NAMES_IDX]),
398426
std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
399427
std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
428+
std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
400429
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
401430
}
402431

@@ -417,6 +446,7 @@ std::vector<std::string> TRTEngine::serialize() {
417446
serialized_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(this->in_binding_names);
418447
serialized_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(this->out_binding_names);
419448
serialized_info[HW_COMPATIBLE_IDX] = this->hardware_compatible ? "1" : "0";
449+
serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";
420450
serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
421451
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
422452

core/runtime/TRTEngine.h

+35
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ using FlattenedState = std::tuple<
2727
std::tuple<std::string, std::string>, // input binding names
2828
std::tuple<std::string, std::string>, // output binding names
2929
std::tuple<std::string, std::string>, // HW compatibility
30+
std::tuple<std::string, std::string>, // requires_output_allocator
3031
std::tuple<std::string, std::string>, // serialized metadata
3132
std::tuple<std::string, std::string>>; // Platform
3233

@@ -69,6 +70,33 @@ struct TorchTRTRuntimeStates {
6970
}
7071
};
7172

73+
// Output allocator given to TensorRT for engines whose output shapes are not
// known until enqueue time. TensorRT invokes reallocateOutputAsync() to
// request device memory for an output and notifyShape() to report its final
// dimensions (see nvinfer1::IOutputAllocator).
class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
 public:
  // output_dtypes: torch scalar type per output binding name; used when
  // materializing the backing at::Tensor for that output.
  DynamicOutputAllocator(const std::unordered_map<std::string, at::ScalarType>& output_dtypes);

  // Callback from TensorRT requesting `size` bytes for output `tensorName`.
  // Returns a device pointer backed by an at::Tensor kept in `buffers`.
  void* reallocateOutputAsync(
      char const* tensorName,
      void* currentMemory,
      uint64_t size,
      uint64_t alignment,
      cudaStream_t stream) override;

  // Callback from TensorRT reporting the final runtime dims of `tensorName`.
  void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override;

  // Backing tensors keyed by output binding name.
  const std::unordered_map<std::string, at::Tensor>& getBuffers() const {
    return buffers;
  }

  // Final shapes recorded by notifyShape(), keyed by output binding name.
  const std::unordered_map<std::string, nvinfer1::Dims>& getShapes() const {
    return shapes;
  }

 private:
  std::unordered_map<std::string, at::ScalarType> dtypes;
  std::unordered_map<std::string, at::Tensor> buffers;
  std::unordered_map<std::string, nvinfer1::Dims> shapes;
};
99+
72100
struct TRTEngine : torch::CustomClassHolder {
73101
// Each engine needs it's own runtime object
74102
std::shared_ptr<nvinfer1::IRuntime> rt;
@@ -99,6 +127,7 @@ struct TRTEngine : torch::CustomClassHolder {
99127
const std::vector<std::string>& out_binding_names,
100128
const Platform& target_platform = get_current_platform(),
101129
bool hardware_compatible = false,
130+
bool requires_output_allocator = false,
102131
const std::string& serialized_metadata = "");
103132

104133
TRTEngine(std::vector<std::string> serialized_info);
@@ -111,6 +140,7 @@ struct TRTEngine : torch::CustomClassHolder {
111140
const std::vector<std::string>& out_binding_names,
112141
const Platform& target_platform = get_current_platform(),
113142
bool hardware_compatible = false,
143+
bool requires_output_allocator = false,
114144
const std::string& serialized_metadata = "");
115145

116146
TRTEngine& operator=(const TRTEngine& other);
@@ -146,6 +176,11 @@ struct TRTEngine : torch::CustomClassHolder {
146176
bool use_pre_allocated_outputs = false;
147177
std::vector<at::Tensor> pre_allocated_outputs;
148178

179+
// Output Allocator-Related Functionality
180+
bool requires_output_allocator = false; // engine requires output allocator
181+
bool use_output_allocator_outputs = false; // users specify to use output allocator
182+
std::shared_ptr<DynamicOutputAllocator> output_allocator;
183+
149184
// TODO: Implement a call method
150185
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
151186

0 commit comments

Comments (0)