
Commit 437151a

feat: Add hardware compatibility option in Dynamo

- Add support for hardware compatibility for Ampere and later architectures
- Add the necessary functions to support the modification throughout the stack, including the C++ and Python components
- Update the ABI version to reflect the new metadata format for TRT engines
- Update the engine serialization schema accordingly
- Add test cases to validate the feature

1 parent 1ff10a6 commit 437151a
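As a quick orientation before the diffs, here is a minimal usage sketch of the new option. The model and input shapes are hypothetical, and it assumes the top-level `torch_tensorrt.compile` entry point forwards keyword arguments to the Dynamo path, where this commit adds the `hardware_compatible` keyword:

```python
import torch
import torchvision.models as models  # hypothetical example model

import torch_tensorrt

model = models.resnet18().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# Build a TensorRT engine that can also run on Ampere-or-newer GPUs
# other than the one it was built on
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    hardware_compatible=True,
)
```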

File tree

14 files changed, +173 -45 lines

core/runtime/TRTEngine.cpp (+17 -5)

```diff
@@ -32,24 +32,35 @@ TRTEngine::TRTEngine(
     const std::string& serialized_engine,
     const RTDevice& cuda_device,
     const std::vector<std::string>& _in_binding_names,
-    const std::vector<std::string>& _out_binding_names)
-    : TRTEngine("deserialized_trt", serialized_engine, cuda_device, _in_binding_names, _out_binding_names) {}
+    const std::vector<std::string>& _out_binding_names,
+    bool hardware_compatible)
+    : TRTEngine(
+          "deserialized_trt",
+          serialized_engine,
+          cuda_device,
+          _in_binding_names,
+          _out_binding_names,
+          hardware_compatible) {}
 
 TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
     : TRTEngine(
           serialized_info[NAME_IDX],
           serialized_info[ENGINE_IDX],
           RTDevice(serialized_info[DEVICE_IDX]),
           split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
-          split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM)) {}
+          split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
+          static_cast<bool>(std::stoi(serialized_info[HARDWARE_COMPATIBLE]))) {}
 
 TRTEngine::TRTEngine(
     const std::string& mod_name,
     const std::string& serialized_engine,
     const RTDevice& cuda_device,
     const std::vector<std::string>& _in_binding_names,
-    const std::vector<std::string>& _out_binding_names) {
-  auto most_compatible_device = get_most_compatible_device(cuda_device);
+    const std::vector<std::string>& _out_binding_names,
+    bool hardware_compatible) {
+  this->hardware_compatible = hardware_compatible;
+
+  auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
   TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
   device_info = most_compatible_device.value();
   set_rt_device(device_info);
@@ -231,6 +242,7 @@ std::string TRTEngine::to_str() const {
   }
   ss << "  }" << std::endl;
   ss << "  Device: " << device_info << std::endl;
+  ss << "  Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
   // clang-format on
   return ss.str();
 }
```

core/runtime/TRTEngine.h (+6 -2)

```diff
@@ -34,19 +34,23 @@ struct TRTEngine : torch::CustomClassHolder {
   std::vector<std::string> in_binding_names = {}; // ITO: PYT IDX
   std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX
 
+  bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
+
   ~TRTEngine();
   TRTEngine(
       const std::string& serialized_engine,
       const RTDevice& cuda_device,
       const std::vector<std::string>& in_binding_names,
-      const std::vector<std::string>& out_binding_names);
+      const std::vector<std::string>& out_binding_names,
+      bool hardware_compatible = false);
   TRTEngine(std::vector<std::string> serialized_info);
   TRTEngine(
       const std::string& mod_name,
       const std::string& serialized_engine,
       const RTDevice& cuda_device,
       const std::vector<std::string>& in_binding_names,
-      const std::vector<std::string>& out_binding_names);
+      const std::vector<std::string>& out_binding_names,
+      bool hardware_compatible = false);
   TRTEngine& operator=(const TRTEngine& other);
   std::string to_str() const;
   static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
```

core/runtime/execute_engine.cpp (+7 -4)

```diff
@@ -43,8 +43,8 @@ bool is_switch_required(const RTDevice& curr_device, const RTDevice& engine_device
   return false;
 }
 
-RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device) {
-  auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device);
+RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device, bool hardware_compatible) {
+  auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device, hardware_compatible);
 
   // REVIEW: THIS DOES NOT LIST DLA PROBABLY, WHICH WE SHOULD
   // TODO: I think this logic could be way simpler at execution time since if the tensors arent on the right
@@ -59,7 +59,9 @@ RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device
 }
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
-  LOG_DEBUG("Attempting to run engine (ID: " << compiled_engine->name << ")");
+  LOG_DEBUG(
+      "Attempting to run engine (ID: " << compiled_engine->name
+                                       << "); Hardware Compatible: " << compiled_engine->hardware_compatible);
 
   if (compiled_engine->profile_execution) {
     std::stringstream ss;
@@ -89,7 +91,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
 
   if (is_switch_required(curr_device, compiled_engine->device_info)) {
     // Scan through available CUDA devices and set the CUDA device context correctly
-    RTDevice device = select_rt_device(compiled_engine->device_info, curr_device);
+    RTDevice device =
+        select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible);
     set_rt_device(device);
 
     // Target device is new device
```

core/runtime/register_jit_hooks.cpp (+3)

```diff
@@ -101,6 +101,9 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
           serialize_info[ENGINE_IDX] = base64_encode(trt_engine);
           serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
           serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
+          serialize_info[HARDWARE_COMPATIBLE] = self->hardware_compatible ? "1" : "0";
+
+          LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
 
           return serialize_info;
         },
```

core/runtime/runtime.cpp (+10 -5)

```diff
@@ -7,9 +7,12 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {
 
-c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
+c10::optional<RTDevice> get_most_compatible_device(
+    const RTDevice& target_device,
+    const RTDevice& curr_device,
+    bool hardware_compatible) {
   LOG_DEBUG("Target Device: " << target_device);
-  auto device_options = find_compatible_devices(target_device);
+  auto device_options = find_compatible_devices(target_device, hardware_compatible);
   RTDevice current_device;
   if (current_device.id == -1) {
     current_device = get_current_device();
@@ -28,7 +31,8 @@ c10::optional<RTDevice> get_most_compatible_device(
   dev_list << "[" << std::endl;
   for (auto device : device_options) {
     dev_list << "    " << device << ',' << std::endl;
-    if (device.device_name == target_device.device_name) {
+    // If the model is hardware compatible, any compatible device should be valid
+    if ((device.device_name == target_device.device_name) || hardware_compatible) {
       // First priority is selecting a candidate which agrees with the current device ID
       // If such a device is found, we can select it and break out of the loop
       if (device.id == current_device.id && best_match.id != current_device.id) {
@@ -58,7 +62,7 @@ c10::optional<RTDevice> get_most_compatible_device(
   }
 }
 
-std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device) {
+std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device, bool hardware_compatible) {
   auto dla_supported = get_dla_supported_SMs();
   auto device_list = get_available_device_list().get_devices();
 
@@ -74,7 +78,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device) {
   } else if (target_device.device_type == nvinfer1::DeviceType::kGPU) {
     auto target_dev_cc = target_device.getSMCapability();
     // If the SM Capabilities match, should be good enough to run
-    if (poss_dev_cc == target_dev_cc) {
+    // If hardware compatibility mode is enabled and the SM is at least 80, device is valid
+    if ((poss_dev_cc == target_dev_cc) || (hardware_compatible && std::stoi(poss_dev_cc) >= 8)) {
      compatible_devices.push_back(device.second);
    }
  } else {
```
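To make the widened check above concrete, here is an illustrative Python paraphrase of the new device predicate (the function name is mine; compute capabilities are strings such as "8.6", and `std::stoi` in the C++ extracts only the leading integer, so the `>= 8` comparison is a major-version floor at Ampere):

```python
def is_compatible_device(poss_dev_cc: str, target_dev_cc: str, hardware_compatible: bool) -> bool:
    # An exact SM capability match is always accepted, as before
    if poss_dev_cc == target_dev_cc:
        return True
    # With hardware compatibility enabled, any Ampere-or-newer GPU
    # (SM major version >= 8) is also accepted
    return hardware_compatible and int(poss_dev_cc.split(".")[0]) >= 8

assert is_compatible_device("8.9", "8.6", hardware_compatible=True)       # Ada runs an Ampere-built engine
assert not is_compatible_device("7.5", "8.6", hardware_compatible=True)   # Turing is below the SM 8.0 floor
assert not is_compatible_device("8.9", "8.6", hardware_compatible=False)  # strict matching without the flag
```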

core/runtime/runtime.h (+5 -3)

```diff
@@ -15,21 +15,23 @@ namespace core {
 namespace runtime {
 
 using EngineID = int64_t;
-const std::string ABI_VERSION = "4";
+const std::string ABI_VERSION = "5";
 typedef enum {
   ABI_TARGET_IDX = 0,
   NAME_IDX,
   DEVICE_IDX,
   ENGINE_IDX,
   INPUT_BINDING_NAMES_IDX,
   OUTPUT_BINDING_NAMES_IDX,
+  HARDWARE_COMPATIBLE,
   SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
 } SerializedInfoIndex;
 
 c10::optional<RTDevice> get_most_compatible_device(
     const RTDevice& target_device,
-    const RTDevice& curr_device = RTDevice());
-std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);
+    const RTDevice& curr_device = RTDevice(),
+    bool hardware_compatible = false);
+std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device, bool hardware_compatible);
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
```
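The new `HARDWARE_COMPATIBLE` slot grows the serialized payload from six to seven string fields, which is why `ABI_VERSION` is bumped from "4" to "5". An illustrative sketch of the layout, with placeholder values (the "1"/"0" encoding follows `register_jit_hooks.cpp` above, and the final cast mirrors the deserializing constructor in `TRTEngine.cpp`):

```python
serialized_info = [
    "5",                       # ABI_TARGET_IDX: bumped ABI version
    "deserialized_trt",        # NAME_IDX
    "<device metadata>",       # DEVICE_IDX (placeholder)
    "<base64 TRT engine>",     # ENGINE_IDX (placeholder)
    "<input binding names>",   # INPUT_BINDING_NAMES_IDX (placeholder)
    "<output binding names>",  # OUTPUT_BINDING_NAMES_IDX (placeholder)
    "1",                       # HARDWARE_COMPATIBLE: "1" enabled, "0" disabled
]
assert len(serialized_info) == 7  # SERIALIZATION_LEN

# Mirrors static_cast<bool>(std::stoi(...)) in the deserializing constructor
hardware_compatible = bool(int(serialized_info[6]))
```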

py/torch_tensorrt/dynamo/_compiler.py (+6 -1)

```diff
@@ -5,7 +5,6 @@
 from typing import Any, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
-import torch_tensorrt
 from torch.export import ExportedProgram
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probably be the TRT EngineCapability Enum
@@ -17,6 +16,7 @@
     DEBUG,
     DEVICE,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    HARDWARE_COMPATIBLE,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -43,6 +43,8 @@
     to_torch_tensorrt_device,
 )
 
+import torch_tensorrt
+
 logger = logging.getLogger(__name__)
 
 
@@ -75,6 +77,7 @@ def compile(
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
     use_fast_partitioner: bool = USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    hardware_compatible: bool = HARDWARE_COMPATIBLE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -131,6 +134,7 @@ def compile(
         use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
         use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
         enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT
+        hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -199,6 +203,7 @@ def compile(
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
         "require_full_compilation": require_full_compilation,
+        "hardware_compatible": hardware_compatible,
     }
 
     settings = CompilationSettings(**compilation_options)
```

py/torch_tensorrt/dynamo/_defaults.py (+1)

```diff
@@ -15,6 +15,7 @@
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
 REQUIRE_FULL_COMPILATION = False
+HARDWARE_COMPATIBLE = False
 
 
 def default_device() -> Device:
```

py/torch_tensorrt/dynamo/_settings.py (+3)

```diff
@@ -6,6 +6,7 @@
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    HARDWARE_COMPATIBLE,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -46,6 +47,7 @@ class CompilationSettings:
         device (Device): GPU to compile the model on
         require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
             Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
+        hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
     """
 
     precision: torch.dtype = PRECISION
@@ -63,3 +65,4 @@ class CompilationSettings:
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     device: Device = field(default_factory=default_device)
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION
+    hardware_compatible: bool = HARDWARE_COMPATIBLE
```
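Because `CompilationSettings` is a dataclass whose fields all have defaults, the new flag composes with the existing options; a minimal sketch:

```python
from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(hardware_compatible=True)
assert settings.hardware_compatible
# Remaining fields keep their defaults, e.g. require_full_compilation is False
```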

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+26 -16)

```diff
@@ -4,6 +4,9 @@
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
 import numpy as np
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -23,8 +26,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -118,7 +119,6 @@ def validate_conversion(self) -> Set[str]:
 
     def run(
         self,
-        workspace_size: int = 0,
         precision: torch.dtype = torch.float32,  # TODO: @peri044 Needs to be expanded to set
         sparse_weights: bool = False,
         disable_tf32: bool = False,
@@ -128,14 +128,10 @@ def run(
         timing_cache: Optional[trt.ITimingCache] = None,
         profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
         tactic_sources: Optional[int] = None,
-        max_aux_streams: Optional[int] = None,
-        version_compatible: bool = False,
-        optimization_level: Optional[int] = None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
         Args:
-            workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
             precision: the precision model layers are running on (TensorRT will choose the best performance precision).
             sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
             force_fp32_output: force output to be fp32
@@ -172,9 +168,10 @@ def run(
 
         builder_config = self.builder.create_builder_config()
 
-        if workspace_size != 0:
+        if self.ctx.compilation_settings.workspace_size != 0:
             builder_config.set_memory_pool_limit(
-                trt.MemoryPoolType.WORKSPACE, workspace_size
+                trt.MemoryPoolType.WORKSPACE,
+                self.ctx.compilation_settings.workspace_size,
             )
 
         cache = None
@@ -193,15 +190,28 @@ def run(
         )
 
         if version.parse(trt.__version__) >= version.parse("8.6"):
-            if max_aux_streams is not None:
-                _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
-                builder_config.max_aux_streams = max_aux_streams
-            if version_compatible:
+            if self.ctx.compilation_settings.max_aux_streams is not None:
+                _LOGGER.info(
+                    f"Setting max aux streams to {self.ctx.compilation_settings.max_aux_streams}"
+                )
+                builder_config.max_aux_streams = (
+                    self.ctx.compilation_settings.max_aux_streams
+                )
+            if self.ctx.compilation_settings.version_compatible:
                 _LOGGER.info("Using version compatible")
                 builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
-            if optimization_level is not None:
-                _LOGGER.info(f"Using optimization level {optimization_level}")
-                builder_config.builder_optimization_level = optimization_level
+            if self.ctx.compilation_settings.hardware_compatible:
+                _LOGGER.info("Using hardware compatible")
+                builder_config.hardware_compatibility_level = (
+                    trt.HardwareCompatibilityLevel.AMPERE_PLUS
+                )
+            if self.ctx.compilation_settings.optimization_level is not None:
+                _LOGGER.info(
+                    f"Using optimization level {self.ctx.compilation_settings.optimization_level}"
+                )
+                builder_config.builder_optimization_level = (
+                    self.ctx.compilation_settings.optimization_level
+                )
 
         if precision == torch.float16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
```
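Isolated from the interpreter, the builder-configuration portion of this change reduces to the following sketch against the TensorRT Python API (network construction and the rest of the build are omitted; the version gate matters because hardware compatibility levels first appeared in TensorRT 8.6):

```python
import tensorrt as trt
from packaging import version

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
builder_config = builder.create_builder_config()

if version.parse(trt.__version__) >= version.parse("8.6"):
    # Restrict the engine to instructions available on Ampere and newer,
    # making the serialized engine portable across those architectures
    builder_config.hardware_compatibility_level = (
        trt.HardwareCompatibilityLevel.AMPERE_PLUS
    )
```

Hardware-compatible engines may give up some architecture-specific tuning, which is presumably why the flag is opt-in and defaults to `False`.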

py/torch_tensorrt/dynamo/conversion/_conversion.py (+2 -6)

```diff
@@ -3,15 +3,14 @@
 import io
 from typing import Sequence
 
+import tensorrt as trt
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 from torch_tensorrt.dynamo.utils import get_torch_inputs
 
-import tensorrt as trt
-
 
 def convert_module(
     module: torch.fx.GraphModule,
@@ -55,16 +54,12 @@ def convert_module(
         compilation_settings=settings,
     )
     interpreter_result = interpreter.run(
-        workspace_size=settings.workspace_size,
         precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
             if settings.debug
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         ),
-        max_aux_streams=settings.max_aux_streams,
-        version_compatible=settings.version_compatible,
-        optimization_level=settings.optimization_level,
     )
 
     if settings.use_python_runtime:
@@ -86,4 +81,5 @@ def convert_module(
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
         target_device=settings.device,
+        hardware_compatible=settings.hardware_compatible,
     )
```
