diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 13cbe3a126..92e5d7a8ff 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -32,8 +32,15 @@ TRTEngine::TRTEngine(
     const std::string& serialized_engine,
     const RTDevice& cuda_device,
     const std::vector<std::string>& _in_binding_names,
-    const std::vector<std::string>& _out_binding_names)
-    : TRTEngine("deserialized_trt", serialized_engine, cuda_device, _in_binding_names, _out_binding_names) {}
+    const std::vector<std::string>& _out_binding_names,
+    bool hardware_compatible)
+    : TRTEngine(
+          "deserialized_trt",
+          serialized_engine,
+          cuda_device,
+          _in_binding_names,
+          _out_binding_names,
+          hardware_compatible) {}
 
 TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
     : TRTEngine(
@@ -41,15 +48,19 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
           serialized_info[ENGINE_IDX],
           RTDevice(serialized_info[DEVICE_IDX]),
           split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
-          split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM)) {}
+          split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
+          static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {}
 
 TRTEngine::TRTEngine(
     const std::string& mod_name,
     const std::string& serialized_engine,
     const RTDevice& cuda_device,
     const std::vector<std::string>& _in_binding_names,
-    const std::vector<std::string>& _out_binding_names) {
-  auto most_compatible_device = get_most_compatible_device(cuda_device);
+    const std::vector<std::string>& _out_binding_names,
+    bool hardware_compatible) {
+  this->hardware_compatible = hardware_compatible;
+
+  auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
   TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
   device_info = most_compatible_device.value();
   multi_gpu_device_check();
@@ -232,6 +243,7 @@ std::string TRTEngine::to_str() const {
   }
   ss << "  }" << std::endl;
   ss << "  Device: " << device_info << std::endl;
+  ss << "  Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
  // clang-format on
  return ss.str();
}
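The constructor changes above thread the new `hardware_compatible` flag through deserialization, where it is stored as the string `"1"` or `"0"` and recovered with `static_cast<bool>(std::stoi(...))`. A minimal Python sketch of that round-trip (helper names are hypothetical, for illustration only):

```python
# Hypothetical mirror of the C++ round-trip in TRTEngine.cpp: the flag
# is serialized as "1"/"0" and parsed back to a bool on deserialization.
def encode_hw_compat_flag(hardware_compatible: bool) -> str:
    return "1" if hardware_compatible else "0"

def decode_hw_compat_flag(field: str) -> bool:
    return bool(int(field))

assert decode_hw_compat_flag(encode_hw_compat_flag(True)) is True
assert decode_hw_compat_flag("0") is False
```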
"Enabled" : "Disabled") << std::endl; // clang-format on return ss.str(); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 3f165ee0c0..3d52aa2689 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -34,19 +34,23 @@ struct TRTEngine : torch::CustomClassHolder { std::vector in_binding_names = {}; // ITO: PYT IDX std::vector out_binding_names = {}; // ITO: PYT IDX + bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode + ~TRTEngine(); TRTEngine( const std::string& serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, - const std::vector& out_binding_names); + const std::vector& out_binding_names, + bool hardware_compatible = false); TRTEngine(std::vector serialized_info); TRTEngine( const std::string& mod_name, const std::string& serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, - const std::vector& out_binding_names); + const std::vector& out_binding_names, + bool hardware_compatible = false); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; static void verify_serialization_fmt(const std::vector& serialized_info); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 5551010a2a..5ff163fbfb 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -43,8 +43,8 @@ bool is_switch_required(const RTDevice& curr_device, const RTDevice& engine_devi return false; } -RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device) { - auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device); +RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device, bool hardware_compatible) { + auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device, hardware_compatible); // REVIEW: THIS DOES NOT LIST DLA PROBABLY, WHICH WE SHOULD // TODO: I think this logic could be way simpler at execution time since if the tensors arent on the right @@ -59,7 +59,9 @@ RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_de } std::vector execute_engine(std::vector inputs, c10::intrusive_ptr compiled_engine) { - LOG_DEBUG("Attempting to run engine (ID: " << compiled_engine->name << ")"); + LOG_DEBUG( + "Attempting to run engine (ID: " << compiled_engine->name + << "); Hardware Compatible: " << compiled_engine->hardware_compatible); if (compiled_engine->profile_execution) { std::stringstream ss; @@ -89,7 +91,8 @@ std::vector execute_engine(std::vector inputs, c10::intr if (is_switch_required(curr_device, compiled_engine->device_info)) { // Scan through available CUDA devices and set the CUDA device context correctly - RTDevice device = select_rt_device(compiled_engine->device_info, curr_device); + RTDevice device = + select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible); set_rt_device(device); // Target device is new device diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 5ad0efb3b0..4ae4f92337 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -101,6 +101,9 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = serialize_info[ENGINE_IDX] = base64_encode(trt_engine); serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names); serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names); + 
diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp
index 5ad0efb3b0..4ae4f92337 100644
--- a/core/runtime/register_jit_hooks.cpp
+++ b/core/runtime/register_jit_hooks.cpp
@@ -101,6 +101,9 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
               serialize_info[ENGINE_IDX] = base64_encode(trt_engine);
               serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
               serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
+              serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
+
+              LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
 
               return serialize_info;
             },
diff --git a/core/runtime/runtime.cpp b/core/runtime/runtime.cpp
index 2d7f7f1198..afe559ec4d 100644
--- a/core/runtime/runtime.cpp
+++ b/core/runtime/runtime.cpp
@@ -9,9 +9,12 @@ namespace runtime {
 
 bool MULTI_DEVICE_SAFE_MODE = false;
 
-c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
+c10::optional<RTDevice> get_most_compatible_device(
+    const RTDevice& target_device,
+    const RTDevice& curr_device,
+    bool hardware_compatible) {
   LOG_DEBUG("Target Device: " << target_device);
-  auto device_options = find_compatible_devices(target_device);
+  auto device_options = find_compatible_devices(target_device, hardware_compatible);
   RTDevice current_device;
   if (current_device.id == -1) {
     current_device = get_current_device();
@@ -30,7 +33,8 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
   dev_list << "[" << std::endl;
   for (auto device : device_options) {
     dev_list << "    " << device << ',' << std::endl;
-    if (device.device_name == target_device.device_name) {
+    // If the model is hardware compatible, any compatible device should be valid
+    if ((device.device_name == target_device.device_name) || hardware_compatible) {
       // First priority is selecting a candidate which agrees with the current device ID
       // If such a device is found, we can select it and break out of the loop
       if (device.id == current_device.id) {
@@ -60,7 +64,7 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
   }
 }
 
-std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device) {
+std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device, bool hardware_compatible) {
   auto dla_supported = get_dla_supported_SMs();
   auto device_list = get_available_device_list().get_devices();
 
@@ -76,7 +80,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device) {
     } else if (target_device.device_type == nvinfer1::DeviceType::kGPU) {
       auto target_dev_cc = target_device.getSMCapability();
       // If the SM Capabilities match, should be good enough to run
-      if (poss_dev_cc == target_dev_cc) {
+      // If hardware compatibility mode is enabled and the SM is at least 80, device is valid
+      if ((poss_dev_cc == target_dev_cc) || (hardware_compatible && std::stoi(poss_dev_cc) >= 8)) {
         compatible_devices.push_back(device.second);
       }
     } else {
diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h
index ea863850ba..5e5676b11e 100644
--- a/core/runtime/runtime.h
+++ b/core/runtime/runtime.h
@@ -15,7 +15,7 @@ namespace core {
 namespace runtime {
 
 using EngineID = int64_t;
-const std::string ABI_VERSION = "4";
+const std::string ABI_VERSION = "5";
 extern bool MULTI_DEVICE_SAFE_MODE;
 typedef enum {
   ABI_TARGET_IDX = 0,
@@ -24,13 +24,15 @@ typedef enum {
   ENGINE_IDX,
   INPUT_BINDING_NAMES_IDX,
   OUTPUT_BINDING_NAMES_IDX,
+  HW_COMPATIBLE_IDX,
   SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
 } SerializedInfoIndex;
 
 c10::optional<RTDevice> get_most_compatible_device(
     const RTDevice& target_device,
-    const RTDevice& curr_device = RTDevice());
-std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);
+    const RTDevice& curr_device = RTDevice(),
+    bool hardware_compatible = false);
+std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device, bool hardware_compatible);
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
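The ABI version bump to `"5"` accounts for the extra serialized field. A hypothetical Python mirror of the `SerializedInfoIndex` layout above, useful for seeing where `HW_COMPATIBLE_IDX` lands in the record:

```python
from enum import IntEnum

# Hypothetical Python mirror of SerializedInfoIndex in runtime.h
# (ABI version "5"); each slot holds a string in the serialized record.
class SerializedInfoIndex(IntEnum):
    ABI_TARGET_IDX = 0
    NAME_IDX = 1
    DEVICE_IDX = 2
    ENGINE_IDX = 3
    INPUT_BINDING_NAMES_IDX = 4
    OUTPUT_BINDING_NAMES_IDX = 5
    HW_COMPATIBLE_IDX = 6
    SERIALIZATION_LEN = 7  # never used for data; gives the record length
```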
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index ac7a323545..5edb7cc32a 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -22,6 +22,7 @@
     DRYRUN,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     ENGINE_CAPABILITY,
+    HARDWARE_COMPATIBLE,
    MAX_AUX_STREAMS,
    MIN_BLOCK_SIZE,
    NUM_AVG_TIMING_ITERS,
@@ -94,6 +95,7 @@ def compile(
     use_fast_partitioner: bool = USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     dryrun: bool = DRYRUN,
+    hardware_compatible: bool = HARDWARE_COMPATIBLE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -151,6 +153,7 @@ def compile(
         use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
         enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
         dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
+        hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -227,6 +230,7 @@ def compile(
         "dla_local_dram_size": dla_local_dram_size,
         "dla_global_dram_size": dla_global_dram_size,
         "dryrun": dryrun,
+        "hardware_compatible": hardware_compatible,
     }
 
     settings = CompilationSettings(**compilation_options)
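From the user's perspective the new option is a single keyword argument. A minimal usage sketch (the sample module is hypothetical; the argument style follows the test case added at the end of this patch):

```python
import torch
import torch_tensorrt

class SampleModel(torch.nn.Module):  # hypothetical stand-in model
    def forward(self, x):
        return torch.softmax((x * 7) @ x.T, dim=0)

inputs = [torch.randn(5, 7).cuda()]

# Build engines that remain loadable on other Ampere-or-newer GPUs,
# using the C++ runtime so the flag is carried in the TRTEngine class
trt_model = torch_tensorrt.compile(
    SampleModel().eval().cuda(),
    "dynamo",
    inputs,
    hardware_compatible=True,
    use_python_runtime=False,
)
```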
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 4afabe60eb..3d48ab3def 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -25,6 +25,7 @@
 REFIT = False
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
+HARDWARE_COMPATIBLE = False
 
 
 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 60990bda99..2992496665 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -13,6 +13,7 @@
     DRYRUN,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     ENGINE_CAPABILITY,
+    HARDWARE_COMPATIBLE,
    MAX_AUX_STREAMS,
    MIN_BLOCK_SIZE,
    NUM_AVG_TIMING_ITERS,
@@ -67,6 +68,7 @@ class CompilationSettings:
         dryrun (Union[bool, str]): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
             TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
             output to a file if a string path is specified
+        hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
     """
 
     precision: torch.dtype = PRECISION
@@ -93,3 +95,4 @@ class CompilationSettings:
     dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
     dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
     dryrun: Union[bool, str] = DRYRUN
+    hardware_compatible: bool = HARDWARE_COMPATIBLE
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index eec7e62516..5db9fc183e 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -188,6 +188,11 @@ def run(
         if self.compilation_settings.version_compatible:
             _LOGGER.info("Using version compatible")
             builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
+        if self.compilation_settings.hardware_compatible:
+            _LOGGER.info("Using hardware compatible")
+            builder_config.hardware_compatibility_level = (
+                trt.HardwareCompatibilityLevel.AMPERE_PLUS
+            )
         if self.compilation_settings.optimization_level is not None:
             _LOGGER.info(
                 f"Using optimization level {self.compilation_settings.optimization_level}"
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index e4f0df5818..9359796711 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -76,4 +76,5 @@ def convert_module(
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
         target_device=settings.device,
+        hardware_compatible=settings.hardware_compatible,
     )
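The interpreter change maps the setting directly onto TensorRT's own builder-config knob. A standalone sketch against the TensorRT Python API (requires TensorRT >= 8.6, where `HardwareCompatibilityLevel` was introduced):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# Constrain the engine to instructions available on Ampere and newer,
# trading some peak performance for cross-architecture portability
config.hardware_compatibility_level = trt.HardwareCompatibilityLevel.AMPERE_PLUS
```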
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index d5ad5021e2..709c10b36e 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -9,10 +9,10 @@
 logger = logging.getLogger(__name__)
 
 SerializedTensorRTEngineFmt = Tuple[
-    str, str, bytes, str, str
+    str, str, str, bytes, str, str, str
 ]  # Defined in //core/runtime/register_jit_hooks.cpp
 SerializedTorchTensorRTModuleFmt = Tuple[
-    str, SerializedTensorRTEngineFmt, List[str], List[str]
+    str, Optional[SerializedTensorRTEngineFmt], List[str], List[str]
 ]
 
 
@@ -43,6 +43,7 @@ def __init__(
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
         target_device: Device = Device._current_device(),
+        hardware_compatible: bool = False,
     ):
         """__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule
 
@@ -89,6 +90,7 @@ def __init__(
             output_binding_names if output_binding_names is not None else []
         )
         self.name = name
+        self.hardware_compatible = hardware_compatible
 
         if serialized_engine is not None:
             self.engine = torch.classes.tensorrt.Engine(
@@ -99,6 +101,7 @@ def __init__(
                     serialized_engine,
                     TorchTensorRTModule._pack_binding_names(self.input_binding_names),
                     TorchTensorRTModule._pack_binding_names(self.output_binding_names),
+                    str(int(hardware_compatible)),
                 ]
             )
         else:
@@ -115,7 +118,7 @@ def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
     def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
         self.name = state[0]
         if state[1] is not None:
-            serialized_engine_info = state[1][0]
+            serialized_engine_info: SerializedTensorRTEngineFmt = state[1]
             import base64
 
             serialized_engine = base64.b64decode(serialized_engine_info[3])
@@ -127,6 +130,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
                     serialized_engine,
                     serialized_engine_info[4],
                     serialized_engine_info[5],
+                    serialized_engine_info[6],
                 ]
             )
         else:
@@ -134,6 +138,9 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
 
         self.input_binding_names = state[2]
         self.output_binding_names = state[3]
+        self.hardware_compatible = (
+            bool(int(state[1][6])) if state[1] is not None else False
+        )
 
     def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         """Implementation of the forward pass for a TensorRT engine
diff --git a/tests/py/dynamo/runtime/test_hw_compat.py b/tests/py/dynamo/runtime/test_hw_compat.py
new file mode 100644
index 0000000000..9ee7206adf
--- /dev/null
+++ b/tests/py/dynamo/runtime/test_hw_compat.py
@@ -0,0 +1,77 @@
+import os
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+import torch_tensorrt
+
+
+class TestHardwareCompatibility(TestCase):
+    def test_hw_compat_enabled(self):
+        class SampleModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.softmax((x * 7) @ x.T, dim=0)
+
+        inputs = [torch.randn(5, 7).cuda()]
+
+        # Validate that the hardware compatibility mode has been enabled
+        optimized_model_hw_compat = torch_tensorrt.compile(
+            SampleModel(),
+            "dynamo",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            hardware_compatible=True,
+            use_python_runtime=False,
+        )
+
+        self.assertTrue(optimized_model_hw_compat._run_on_acc_0.hardware_compatible)
+
+        cpp_repr = optimized_model_hw_compat._run_on_acc_0.engine._properties.__repr__()
+
+        self.assertIn("Hardware Compatibility: Enabled", cpp_repr)
+
+        # Validate that the hardware compatibility mode has been disabled
+        optimized_model_not_hw_compat = torch_tensorrt.compile(
+            SampleModel(),
+            "dynamo",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            hardware_compatible=False,
+            use_python_runtime=False,
+        )
+
+        self.assertFalse(
+            optimized_model_not_hw_compat._run_on_acc_0.hardware_compatible
+        )
+
+        cpp_repr = (
+            optimized_model_not_hw_compat._run_on_acc_0.engine._properties.__repr__()
+        )
+
+        self.assertIn("Hardware Compatibility: Disabled", cpp_repr)
+
+    @unittest.skipIf(
+        torch.ops.tensorrt.ABI_VERSION() != "5",
+        "Detected incorrect ABI version, please update this test case",
+    )
+    def test_hw_compat_3080_build(self):
+        inputs = [torch.randn(5, 7).cuda()]
+
+        cwd = os.getcwd()
+        os.chdir(os.path.dirname(os.path.realpath(__file__)))
+        model = torch.jit.load("../../ts/models/hw_compat.ts").cuda()
+        out = model(*inputs)
+        self.assertTrue(
+            isinstance(out, tuple)
+            and len(out) == 1
+            and isinstance(out[0], torch.Tensor),
+            "Invalid output detected",
+        )
+        os.chdir(cwd)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/ts/models/hw_compat.ts b/tests/py/ts/models/hw_compat.ts
new file mode 100644
index 0000000000..ab43e5e040
Binary files /dev/null and b/tests/py/ts/models/hw_compat.ts differ
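Because the flag survives (de)serialization (exercised above via the pre-built `hw_compat.ts` artifact), a hardware-compatible engine built on one Ampere-class GPU can be reloaded on another. Continuing from the compile sketch earlier, and assuming the TorchScript round-trip that the test relies on:

```python
# Continuing from the earlier sketch: trt_model was compiled with
# hardware_compatible=True and use_python_runtime=False, so the engine
# embeds into TorchScript via the C++ runtime.
trt_script = torch.jit.trace(trt_model, inputs)
torch.jit.save(trt_script, "hw_compat.ts")

# Later, possibly on a different Ampere-or-newer GPU:
reloaded = torch.jit.load("hw_compat.ts").cuda()
out = reloaded(*inputs)
```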