@@ -30,13 +30,37 @@ std::vector<std::string> split(const std::string& str, char delim) {
30
30
return strings;
31
31
}
32
32
33
+ DynamicOutputAllocator::DynamicOutputAllocator (const std::unordered_map<std::string, at::ScalarType>& output_dtypes)
34
+ : dtypes(output_dtypes) {}
35
+
36
+ void * DynamicOutputAllocator::reallocateOutputAsync (
37
+ char const * tensorName,
38
+ void * currentMemory,
39
+ uint64_t size,
40
+ uint64_t alignment,
41
+ cudaStream_t stream) {
42
+ std::vector<int64_t > shape = {static_cast <int64_t >(size)};
43
+ auto it = buffers.find (tensorName);
44
+ if (it == buffers.end () || it->second .sizes () != shape) {
45
+ buffers[tensorName] = at::empty (shape, at::TensorOptions ().dtype (dtypes.at (tensorName)).device (c10::kCUDA ));
46
+ return buffers[tensorName].data_ptr ();
47
+ } else {
48
+ return it->second .data_ptr ();
49
+ }
50
+ }
51
+
52
+ void DynamicOutputAllocator::notifyShape (char const * tensorName, nvinfer1::Dims const & dims) noexcept {
53
+ shapes[tensorName] = dims;
54
+ }
55
+
33
56
TRTEngine::TRTEngine (
34
57
const std::string& serialized_engine,
35
58
const RTDevice& cuda_device,
36
59
const std::vector<std::string>& _in_binding_names,
37
60
const std::vector<std::string>& _out_binding_names,
38
61
const Platform& target_platform,
39
62
bool hardware_compatible,
63
+ bool requires_output_allocator,
40
64
const std::string& serialized_metadata)
41
65
: TRTEngine(
42
66
" deserialized_trt" ,
@@ -46,6 +70,7 @@ TRTEngine::TRTEngine(
46
70
_out_binding_names,
47
71
target_platform,
48
72
hardware_compatible,
73
+ requires_output_allocator,
49
74
serialized_metadata) {}
50
75
51
76
TRTEngine::TRTEngine (std::vector<std::string> serialized_info)
@@ -57,6 +82,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
57
82
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
58
83
Platform(serialized_info[TARGET_PLATFORM_IDX]),
59
84
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
85
+ static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
60
86
serialized_info[SERIALIZED_METADATA_IDX]) {}
61
87
62
88
TRTEngine::TRTEngine (
@@ -67,6 +93,7 @@ TRTEngine::TRTEngine(
67
93
const std::vector<std::string>& _out_binding_names,
68
94
const Platform& target_platform,
69
95
bool hardware_compatible,
96
+ bool requires_output_allocator,
70
97
const std::string& serialized_metadata) {
71
98
TORCHTRT_CHECK (
72
99
is_supported_on_current_platform (target_platform),
@@ -79,6 +106,7 @@ TRTEngine::TRTEngine(
79
106
TORCHTRT_CHECK (most_compatible_device, " No compatible device was found for instantiating TensorRT engine" );
80
107
81
108
this ->serialized_metadata = serialized_metadata;
109
+ this ->requires_output_allocator = requires_output_allocator;
82
110
device_info = most_compatible_device.value ();
83
111
multi_gpu_device_check ();
84
112
set_rt_device (device_info);
@@ -397,6 +425,7 @@ FlattenedState TRTEngine::__obj_flatten__() {
397
425
std::tuple (" out_binding_names" , serialized_info[OUTPUT_BINDING_NAMES_IDX]),
398
426
std::tuple (" hardware_compatible" , serialized_info[HW_COMPATIBLE_IDX]),
399
427
std::tuple (" serialized_metadata" , serialized_info[SERIALIZED_METADATA_IDX]),
428
+ std::tuple (" requires_output_allocator" , serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
400
429
std::tuple (" target_platform" , serialized_info[TARGET_PLATFORM_IDX]));
401
430
}
402
431
@@ -417,6 +446,7 @@ std::vector<std::string> TRTEngine::serialize() {
417
446
serialized_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings (this ->in_binding_names );
418
447
serialized_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings (this ->out_binding_names );
419
448
serialized_info[HW_COMPATIBLE_IDX] = this ->hardware_compatible ? " 1" : " 0" ;
449
+ serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this ->requires_output_allocator ? " 1" : " 0" ;
420
450
serialized_info[SERIALIZED_METADATA_IDX] = this ->serialized_metadata ;
421
451
serialized_info[TARGET_PLATFORM_IDX] = this ->target_platform .serialize ();
422
452
0 commit comments