Skip to content

feat: Add hardware compatibility option in Dynamo #2445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,35 @@ TRTEngine::TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names)
: TRTEngine("deserialized_trt", serialized_engine, cuda_device, _in_binding_names, _out_binding_names) {}
const std::vector<std::string>& _out_binding_names,
bool hardware_compatible)
: TRTEngine(
"deserialized_trt",
serialized_engine,
cuda_device,
_in_binding_names,
_out_binding_names,
hardware_compatible) {}

TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
: TRTEngine(
serialized_info[NAME_IDX],
serialized_info[ENGINE_IDX],
RTDevice(serialized_info[DEVICE_IDX]),
split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM)) {}
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {}

TRTEngine::TRTEngine(
const std::string& mod_name,
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names) {
auto most_compatible_device = get_most_compatible_device(cuda_device);
const std::vector<std::string>& _out_binding_names,
bool hardware_compatible) {
this->hardware_compatible = hardware_compatible;

auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
device_info = most_compatible_device.value();
multi_gpu_device_check();
Expand Down Expand Up @@ -232,6 +243,7 @@ std::string TRTEngine::to_str() const {
}
ss << " }" << std::endl;
ss << " Device: " << device_info << std::endl;
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
// clang-format on
return ss.str();
}
Expand Down
8 changes: 6 additions & 2 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,23 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<std::string> in_binding_names = {}; // ITO: PYT IDX
std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX

bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode

~TRTEngine();
TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names);
const std::vector<std::string>& out_binding_names,
bool hardware_compatible = false);
TRTEngine(std::vector<std::string> serialized_info);
TRTEngine(
const std::string& mod_name,
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names);
const std::vector<std::string>& out_binding_names,
bool hardware_compatible = false);
TRTEngine& operator=(const TRTEngine& other);
std::string to_str() const;
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
Expand Down
11 changes: 7 additions & 4 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ bool is_switch_required(const RTDevice& curr_device, const RTDevice& engine_devi
return false;
}

RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device) {
auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device);
RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device, bool hardware_compatible) {
auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device, hardware_compatible);

// REVIEW: THIS DOES NOT LIST DLA PROBABLY, WHICH WE SHOULD
// TODO: I think this logic could be way simpler at execution time since if the tensors arent on the right
Expand All @@ -59,7 +59,9 @@ RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_de
}

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
LOG_DEBUG("Attempting to run engine (ID: " << compiled_engine->name << ")");
LOG_DEBUG(
"Attempting to run engine (ID: " << compiled_engine->name
<< "); Hardware Compatible: " << compiled_engine->hardware_compatible);

if (compiled_engine->profile_execution) {
std::stringstream ss;
Expand Down Expand Up @@ -89,7 +91,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

if (is_switch_required(curr_device, compiled_engine->device_info)) {
// Scan through available CUDA devices and set the CUDA device context correctly
RTDevice device = select_rt_device(compiled_engine->device_info, curr_device);
RTDevice device =
select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible);
set_rt_device(device);

// Target device is new device
Expand Down
3 changes: 3 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
serialize_info[ENGINE_IDX] = base64_encode(trt_engine);
serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";

LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));

return serialize_info;
},
Expand Down
15 changes: 10 additions & 5 deletions core/runtime/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ namespace runtime {

bool MULTI_DEVICE_SAFE_MODE = false;

c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
const RTDevice& curr_device,
bool hardware_compatible) {
LOG_DEBUG("Target Device: " << target_device);
auto device_options = find_compatible_devices(target_device);
auto device_options = find_compatible_devices(target_device, hardware_compatible);
RTDevice current_device;
if (current_device.id == -1) {
current_device = get_current_device();
Expand All @@ -30,7 +33,8 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
dev_list << "[" << std::endl;
for (auto device : device_options) {
dev_list << " " << device << ',' << std::endl;
if (device.device_name == target_device.device_name) {
// If the model is hardware compatible, any compatible device should be valid
if ((device.device_name == target_device.device_name) || hardware_compatible) {
// First priority is selecting a candidate which agrees with the current device ID
// If such a device is found, we can select it and break out of the loop
if (device.id == current_device.id) {
Expand Down Expand Up @@ -60,7 +64,7 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
}
}

std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device) {
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device, bool hardware_compatible) {
auto dla_supported = get_dla_supported_SMs();
auto device_list = get_available_device_list().get_devices();

Expand All @@ -76,7 +80,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device) {
} else if (target_device.device_type == nvinfer1::DeviceType::kGPU) {
auto target_dev_cc = target_device.getSMCapability();
// If the SM Capabilities match, should be good enough to run
if (poss_dev_cc == target_dev_cc) {
// If hardware compatibility mode is enabled and the SM is at least 80, device is valid
if ((poss_dev_cc == target_dev_cc) || (hardware_compatible && std::stoi(poss_dev_cc) >= 8)) {
compatible_devices.push_back(device.second);
}
} else {
Expand Down
8 changes: 5 additions & 3 deletions core/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace core {
namespace runtime {

using EngineID = int64_t;
const std::string ABI_VERSION = "4";
const std::string ABI_VERSION = "5";
extern bool MULTI_DEVICE_SAFE_MODE;
typedef enum {
ABI_TARGET_IDX = 0,
Expand All @@ -24,13 +24,15 @@ typedef enum {
ENGINE_IDX,
INPUT_BINDING_NAMES_IDX,
OUTPUT_BINDING_NAMES_IDX,
HW_COMPATIBLE_IDX,
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
const RTDevice& curr_device = RTDevice());
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);
const RTDevice& curr_device = RTDevice(),
bool hardware_compatible = false);
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device, bool hardware_compatible);

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);

Expand Down
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DRYRUN,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
Expand Down Expand Up @@ -94,6 +95,7 @@ def compile(
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
dryrun: bool = DRYRUN,
hardware_compatible: bool = HARDWARE_COMPATIBLE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
Expand Down Expand Up @@ -151,6 +153,7 @@ def compile(
use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -227,6 +230,7 @@ def compile(
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"dryrun": dryrun,
"hardware_compatible": hardware_compatible,
}

settings = CompilationSettings(**compilation_options)
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
REFIT = False
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False


def default_device() -> Device:
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DRYRUN,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
Expand Down Expand Up @@ -67,6 +68,7 @@ class CompilationSettings:
dryrun (Union[bool, str]): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
output to a file if a string path is specified
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
"""

precision: torch.dtype = PRECISION
Expand All @@ -93,3 +95,4 @@ class CompilationSettings:
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
dryrun: Union[bool, str] = DRYRUN
hardware_compatible: bool = HARDWARE_COMPATIBLE
5 changes: 5 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ def run(
if self.compilation_settings.version_compatible:
_LOGGER.info("Using version compatible")
builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
if self.compilation_settings.hardware_compatible:
_LOGGER.info("Using hardware compatible")
builder_config.hardware_compatibility_level = (
trt.HardwareCompatibilityLevel.AMPERE_PLUS
)
if self.compilation_settings.optimization_level is not None:
_LOGGER.info(
f"Using optimization level {self.compilation_settings.optimization_level}"
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,5 @@ def convert_module(
input_binding_names=list(interpreter_result.input_names),
output_binding_names=list(interpreter_result.output_names),
target_device=settings.device,
hardware_compatible=settings.hardware_compatible,
)
13 changes: 10 additions & 3 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
logger = logging.getLogger(__name__)

SerializedTensorRTEngineFmt = Tuple[
str, str, bytes, str, str
str, str, str, bytes, str, str, str
] # Defined in //core/runtime/register_jit_hooks.cpp
SerializedTorchTensorRTModuleFmt = Tuple[
str, SerializedTensorRTEngineFmt, List[str], List[str]
str, Optional[SerializedTensorRTEngineFmt], List[str], List[str]
]


Expand Down Expand Up @@ -43,6 +43,7 @@ def __init__(
input_binding_names: Optional[List[str]] = None,
output_binding_names: Optional[List[str]] = None,
target_device: Device = Device._current_device(),
hardware_compatible: bool = False,
):
"""__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule

Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
output_binding_names if output_binding_names is not None else []
)
self.name = name
self.hardware_compatible = hardware_compatible

if serialized_engine is not None:
self.engine = torch.classes.tensorrt.Engine(
Expand All @@ -99,6 +101,7 @@ def __init__(
serialized_engine,
TorchTensorRTModule._pack_binding_names(self.input_binding_names),
TorchTensorRTModule._pack_binding_names(self.output_binding_names),
str(int(hardware_compatible)),
]
)
else:
Expand All @@ -115,7 +118,7 @@ def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
self.name = state[0]
if state[1] is not None:
serialized_engine_info = state[1][0]
serialized_engine_info: SerializedTensorRTEngineFmt = state[1]
import base64

serialized_engine = base64.b64decode(serialized_engine_info[3])
Expand All @@ -127,13 +130,17 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
serialized_engine,
serialized_engine_info[4],
serialized_engine_info[5],
serialized_engine_info[6],
]
)
else:
self.engine = None

self.input_binding_names = state[2]
self.output_binding_names = state[3]
self.hardware_compatible = (
bool(int(state[1][6])) if state[1] is not None else False
)

def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
"""Implementation of the forward pass for a TensorRT engine
Expand Down
Loading