
feat: second attempt to support DDS and NonZero op #3388


Merged
merged 17 commits on Mar 13, 2025
30 changes: 30 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -30,13 +30,37 @@ std::vector<std::string> split(const std::string& str, char delim) {
return strings;
}

DynamicOutputAllocator::DynamicOutputAllocator(const std::unordered_map<std::string, at::ScalarType>& output_dtypes)
: dtypes(output_dtypes) {}

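// Invoked by TensorRT during enqueue once the byte size of a data-dependent
// output is known. Lazily allocates a flat CUDA tensor of `size` elements of
// the output's dtype (an over-allocation, since `size` is a byte count),
// reallocating whenever the requested size changes, and returns its device
// pointer to the runtime.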
void* DynamicOutputAllocator::reallocateOutputAsync(
char const* tensorName,
void* currentMemory,
uint64_t size,
uint64_t alignment,
cudaStream_t stream) {
std::vector<int64_t> shape = {static_cast<int64_t>(size)};
auto it = buffers.find(tensorName);
if (it == buffers.end() || it->second.sizes() != shape) {
buffers[tensorName] = at::empty(shape, at::TensorOptions().dtype(dtypes.at(tensorName)).device(at::kCUDA));
return buffers[tensorName].data_ptr();
} else {
return it->second.data_ptr();
}
}

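// TensorRT reports each output's final (data-dependent) shape here; it is
// recorded so the caller can reshape the flat buffer after enqueue completes.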
void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept {
shapes[tensorName] = dims;
}

TRTEngine::TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names,
const Platform& target_platform,
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata)
: TRTEngine(
"deserialized_trt",
@@ -46,6 +70,7 @@ TRTEngine::TRTEngine(
_out_binding_names,
target_platform,
hardware_compatible,
requires_output_allocator,
serialized_metadata) {}

TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
@@ -57,6 +82,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
Platform(serialized_info[TARGET_PLATFORM_IDX]),
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
serialized_info[SERIALIZED_METADATA_IDX]) {}

TRTEngine::TRTEngine(
@@ -67,6 +93,7 @@ TRTEngine::TRTEngine(
const std::vector<std::string>& _out_binding_names,
const Platform& target_platform,
bool hardware_compatible,
bool requires_output_allocator,
const std::string& serialized_metadata) {
TORCHTRT_CHECK(
is_supported_on_current_platform(target_platform),
@@ -79,6 +106,7 @@ TRTEngine::TRTEngine(
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");

this->serialized_metadata = serialized_metadata;
this->requires_output_allocator = requires_output_allocator;
device_info = most_compatible_device.value();
multi_gpu_device_check();
set_rt_device(device_info);
@@ -397,6 +425,7 @@ FlattenedState TRTEngine::__obj_flatten__() {
std::tuple("out_binding_names", serialized_info[OUTPUT_BINDING_NAMES_IDX]),
std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
}

@@ -417,6 +446,7 @@ std::vector<std::string> TRTEngine::serialize() {
serialized_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(this->in_binding_names);
serialized_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(this->out_binding_names);
serialized_info[HW_COMPATIBLE_IDX] = this->hardware_compatible ? "1" : "0";
serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";
serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();

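For orientation, here is a minimal sketch (not part of this diff) of how such an allocator is wired into TensorRT's enqueue path: register it for every output tensor, run enqueueV3 so TensorRT can call back into reallocateOutputAsync and notifyShape, then reshape each flat buffer to the reported shape. Inputs are assumed to be already bound, `ctx`, `out_dtypes`, and `stream` are illustrative names, and error handling plus headers (<numeric>, <functional>, ATen, NvInfer) are elided.

// Sketch only: enqueue with dynamic output allocation, then materialize outputs.
std::vector<at::Tensor> run_with_dynamic_outputs(
    nvinfer1::IExecutionContext& ctx,
    const std::unordered_map<std::string, at::ScalarType>& out_dtypes,
    cudaStream_t stream) {
  auto allocator = std::make_shared<DynamicOutputAllocator>(out_dtypes);
  for (const auto& kv : out_dtypes) {
    // Register the allocator for each output tensor by name.
    ctx.setOutputAllocator(kv.first.c_str(), allocator.get());
  }
  ctx.enqueueV3(stream); // reallocateOutputAsync/notifyShape fire during enqueue
  cudaStreamSynchronize(stream);

  std::vector<at::Tensor> outputs;
  for (const auto& [name, dims] : allocator->getShapes()) {
    std::vector<int64_t> shape(dims.d, dims.d + dims.nbDims);
    int64_t numel =
        std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
    // Buffers were allocated flat and oversized; keep the valid prefix and reshape.
    outputs.push_back(allocator->getBuffers().at(name).narrow(0, 0, numel).view(shape));
  }
  return outputs;
}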
35 changes: 35 additions & 0 deletions core/runtime/TRTEngine.h
@@ -27,6 +27,7 @@ using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // input binding names
std::tuple<std::string, std::string>, // output binding names
std::tuple<std::string, std::string>, // HW compatibility
std::tuple<std::string, std::string>, // requires_output_allocator
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>>; // Platform

@@ -69,6 +70,33 @@ struct TorchTRTRuntimeStates {
}
};

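// Implements nvinfer1::IOutputAllocator so that outputs of data-dependent
// shape (DDS) ops such as NonZero can be sized during enqueue instead of
// being pre-bound to fixed-size buffers.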
class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
public:
DynamicOutputAllocator(const std::unordered_map<std::string, at::ScalarType>& output_dtypes);

void* reallocateOutputAsync(
char const* tensorName,
void* currentMemory,
uint64_t size,
uint64_t alignment,
cudaStream_t stream) override;

void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override;

const std::unordered_map<std::string, at::Tensor>& getBuffers() const {
return buffers;
}

const std::unordered_map<std::string, nvinfer1::Dims>& getShapes() const {
return shapes;
}

private:
std::unordered_map<std::string, at::ScalarType> dtypes;
std::unordered_map<std::string, at::Tensor> buffers;
std::unordered_map<std::string, nvinfer1::Dims> shapes;
};

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs its own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
@@ -99,6 +127,7 @@ struct TRTEngine : torch::CustomClassHolder {
const std::vector<std::string>& out_binding_names,
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
bool requires_output_allocator = false,
const std::string& serialized_metadata = "");

TRTEngine(std::vector<std::string> serialized_info);
@@ -111,6 +140,7 @@ TRTEngine(
const std::vector<std::string>& out_binding_names,
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
bool requires_output_allocator = false,
const std::string& serialized_metadata = "");

TRTEngine& operator=(const TRTEngine& other);
@@ -146,6 +176,11 @@ struct TRTEngine : torch::CustomClassHolder {
bool use_pre_allocated_outputs = false;
std::vector<at::Tensor> pre_allocated_outputs;

// Output Allocator-Related Functionality
bool requires_output_allocator = false; // engine requires output allocator
bool use_output_allocator_outputs = false; // user opted in to output-allocator-backed outputs
std::shared_ptr<DynamicOutputAllocator> output_allocator;

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

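The serialization change is symmetric: serialize() writes the new flag as "1"/"0" at REQUIRES_OUTPUT_ALLOCATOR_IDX, and the serialized_info constructor parses it back with std::stoi, so the flag survives a round trip. A hypothetical illustration (the `engine` variable and its construction are assumed; <cassert> elided):

// Sketch only: requires_output_allocator round-trips through serialization.
std::vector<std::string> info = engine.serialize(); // engine built with DDS ops
assert(info[REQUIRES_OUTPUT_ALLOCATOR_IDX] == "1");
TRTEngine restored(info); // ctor re-parses the flag via std::stoi
assert(restored.requires_output_allocator);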