
Commit aad7d06

feat: Safety Mode for Runtime (#2512)
1 parent: 6338bd5

16 files changed: +420 −15 lines changed

core/runtime/TRTEngine.cpp

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ TRTEngine::TRTEngine(
   auto most_compatible_device = get_most_compatible_device(cuda_device);
   TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
   device_info = most_compatible_device.value();
+  multi_gpu_device_check();
   set_rt_device(device_info);
 
   rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     LOG_INFO("" << log_info);
   }
 
-  {
+  if (MULTI_DEVICE_SAFE_MODE) {
     std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
     if (compiled_engine->profile_execution) {
       device_profiler_guard =

core/runtime/register_jit_hooks.cpp

Lines changed: 4 additions & 0 deletions
@@ -114,6 +114,10 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("execute_engine", execute_engine);
   m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
   m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
+  m.def("get_multi_device_safe_mode", []() -> bool { return MULTI_DEVICE_SAFE_MODE; });
+  m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
+    MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
+  });
 }
 
 } // namespace
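
Since these definitions live inside TORCH_LIBRARY(tensorrt, m), the new getter and setter should surface in Python as ops under torch.ops.tensorrt. A minimal sketch of calling them directly, assuming the Torch-TensorRT runtime library has been loaded:

    import torch
    import torch_tensorrt  # importing the package loads the library that registers the ops

    # Read the current flag from the C++ runtime via the registered op
    print(torch.ops.tensorrt.get_multi_device_safe_mode())  # False by default

    # Flip the flag in the C++ runtime
    torch.ops.tensorrt.set_multi_device_safe_mode(True)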

core/runtime/runtime.cpp

Lines changed: 15 additions & 2 deletions
@@ -7,6 +7,8 @@ namespace torch_tensorrt {
 namespace core {
 namespace runtime {
 
+bool MULTI_DEVICE_SAFE_MODE = false;
+
 c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
   LOG_DEBUG("Target Device: " << target_device);
   auto device_options = find_compatible_devices(target_device);
@@ -31,13 +33,13 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
     if (device.device_name == target_device.device_name) {
       // First priority is selecting a candidate which agrees with the current device ID
       // If such a device is found, we can select it and break out of the loop
-      if (device.id == current_device.id && best_match.id != current_device.id) {
+      if (device.id == current_device.id) {
         best_match = device;
         break;
       }
       // Second priority is selecting a candidate which agrees with the target device ID
       // At deserialization time, the current device and target device may not agree
-      else if (device.id == target_device.id && best_match.id != target_device.id) {
+      else if (device.id == target_device.id) {
         best_match = device;
       }
       // If no such GPU ID is found, select the first available candidate GPU
@@ -103,6 +105,17 @@ RTDevice get_current_device() {
   return RTDevice(device_id, nvinfer1::DeviceType::kGPU);
 }
 
+void multi_gpu_device_check() {
+  // If multi-device safe mode is disabled and more than 1 device is registered on the machine, warn user
+  if (!(MULTI_DEVICE_SAFE_MODE) && get_available_device_list().get_devices().size() > 1) {
+    LOG_WARNING(
+        "Detected this engine is being instantiated in a multi-GPU system with "
+        << "multi-device safe mode disabled. For more on the implications of this "
+        << "as well as workarounds, see the linked documentation "
+        << "(https://pytorch.org/TensorRT/user_guide/runtime.html#multi-device-safe-mode)");
+  }
+}
+
 namespace {
 static DeviceList cuda_device_list;
 }
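
For readers skimming the diff, the selection order in get_most_compatible_device is: a candidate whose ID matches the current device wins outright, then one matching the target device ID, then the first compatible candidate. A rough Python restatement of that priority (illustrative only; the names are hypothetical, not the actual API):

    def select_most_compatible(candidates, current_id, target_id):
        # Hypothetical sketch of the priority order, not the C++ implementation
        best_match = None
        for device in candidates:
            if device.id == current_id:
                return device          # 1st priority: agrees with the current device
            if device.id == target_id:
                best_match = device    # 2nd priority: agrees with the target device
            elif best_match is None:
                best_match = device    # fallback: first compatible candidate
        return best_match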

core/runtime/runtime.h

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@ namespace runtime {
 
 using EngineID = int64_t;
 const std::string ABI_VERSION = "4";
+extern bool MULTI_DEVICE_SAFE_MODE;
 typedef enum {
   ABI_TARGET_IDX = 0,
   NAME_IDX,
@@ -33,6 +34,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
 
+void multi_gpu_device_check();
+
 class DeviceList {
   using DeviceMap = std::unordered_map<int, RTDevice>;
   DeviceMap device_list;

docsrc/user_guide/runtime.rst

Lines changed: 34 additions & 0 deletions
@@ -34,3 +34,37 @@ Plugin Library
 In the case you use Torch-TensorRT as a converter to a TensorRT engine and your engine uses plugins provided by Torch-TensorRT, Torch-TensorRT
 ships the library ``libtorchtrt_plugins.so`` which contains the implementation of the TensorRT plugins used by Torch-TensorRT during
 compilation. This library can be ``DL_OPEN`` or ``LD_PRELOAD`` similar to other TensorRT plugin libraries.
+
+Multi Device Safe Mode
+----------------------
+
+Multi-device safe mode is a setting in Torch-TensorRT which allows the user to determine whether
+the runtime checks for device consistency prior to every inference call.
+
+There is a non-negligible, fixed cost per inference call when multi-device safe mode is enabled, which is why
+it is now disabled by default. It can be controlled via the following convenience function, which
+doubles as a context manager.
+
+.. code-block:: python
+
+    # Enables Multi Device Safe Mode
+    torch_tensorrt.runtime.set_multi_device_safe_mode(True)
+
+    # Disables Multi Device Safe Mode [Default Behavior]
+    torch_tensorrt.runtime.set_multi_device_safe_mode(False)
+
+    # Enables Multi Device Safe Mode, then resets the safe mode to its prior setting
+    with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
+        ...
+
+TensorRT requires that each engine be associated with the CUDA context in the active thread from which it is invoked.
+Therefore, if the device were to change in the active thread, which may be the case when invoking
+engines on multiple GPUs from the same Python process, safe mode will cause Torch-TensorRT to display
+an alert and switch GPUs accordingly. If safe mode is not enabled, the engine device and the CUDA
+context device could become mismatched, which can cause the program to crash.
+
+One technique for managing multiple TRT engines on different GPUs while not sacrificing performance for
+multi-device safe mode is to use Python threads. Each thread is responsible for all of the TRT engines
+on a single GPU, and the default CUDA device on each thread corresponds to the GPU for which it is
+responsible (it can be set via ``torch.cuda.set_device(...)``). In this way, multiple threads can be used in the same
+Python script without needing to switch CUDA contexts and incur performance overhead.
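
The thread-per-GPU approach described in the new documentation could look roughly like the following. This is a minimal sketch, not code from this commit; trt_mod_0, trt_mod_1, and sample_inputs stand in for pre-compiled Torch-TensorRT modules and their inputs:

    import threading

    import torch

    def infer_on_gpu(device_id, trt_module, inputs, results):
        # Pin this thread's default CUDA device to the GPU that owns the engine,
        # so no device switch (and no safe-mode overhead) is needed per call
        torch.cuda.set_device(device_id)
        results[device_id] = trt_module(*[t.to(f"cuda:{device_id}") for t in inputs])

    results = {}
    threads = [
        threading.Thread(target=infer_on_gpu, args=(i, mod, sample_inputs, results))
        for i, mod in enumerate([trt_mod_0, trt_mod_1])  # hypothetical compiled modules
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()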

py/torch_tensorrt/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -85,15 +85,17 @@ def _find_lib(name: str, paths: List[str]) -> str:
 from torch_tensorrt._Device import Device  # noqa: F401
 from torch_tensorrt._enums import *  # noqa: F403
 from torch_tensorrt._Input import Input  # noqa: F401
-from torch_tensorrt.logging import *
-from torch_tensorrt.ptq import *
 from torch_tensorrt._utils import *  # noqa: F403
 from torch_tensorrt._utils import sanitized_torch_version
+from torch_tensorrt.logging import *
+from torch_tensorrt.ptq import *
+from torch_tensorrt.runtime import *  # noqa: F403
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from torch_tensorrt import dynamo  # noqa: F401
     from torch_tensorrt.dynamo import backend  # noqa: F401
 
+    from torch_tensorrt import dynamo  # noqa: F401
+
 
 def _register_with_torch() -> None:
     trtorch_dir = os.path.dirname(__file__)

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ def convert_module(
             engine=interpreter_result.engine,
             input_names=list(interpreter_result.input_names),
             output_names=list(interpreter_result.output_names),
+            target_device=settings.device,
+            profiling_enabled=settings.debug,
         )
 
     else:

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 2 additions & 2 deletions
@@ -42,10 +42,10 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if (
-            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or node.op == "get_attr"
        ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
-            if not node.is_impure():
+            if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
                     self.supported_operators[node_name] = 1
                 else:

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 2 additions & 2 deletions
@@ -150,10 +150,10 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if (
-            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or node.op == "get_attr"
        ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
-            if not node.is_impure():
+            if not node.is_impure() and node.op != "get_attr":
                 if node_name not in self.supported_operators:
                     self.supported_operators[node_name] = 1
                 else:
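
Context for the two partitioner changes above: in torch.fx, parameter and buffer accesses appear in traced graphs as get_attr nodes. The revised predicate accepts every get_attr node as supported (rather than only those with "constant" in their name), while the second change excludes them from the supported-operator tally. A small, self-contained illustration of where such nodes come from (not code from this commit):

    import torch
    import torch.fx
    from torch import nn

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(3))

        def forward(self, x):
            return x + self.weight

    # Tracing surfaces the parameter access as a get_attr node
    graph = torch.fx.symbolic_trace(M()).graph
    print([node.op for node in graph.nodes])
    # ['placeholder', 'get_attr', 'call_function', 'output']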

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 61 additions & 5 deletions
@@ -1,13 +1,22 @@
 from __future__ import annotations
 
 import logging
+from contextlib import nullcontext
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import tensorrt as trt
 import torch
 from torch.nn import Module
+from torch_tensorrt._Device import Device
+from torch_tensorrt.dynamo.runtime.tools import (
+    _is_switch_required,
+    _select_rt_device,
+    multi_gpu_device_check,
+)
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
+import torch_tensorrt
+
 logger = logging.getLogger(__name__)
 
 
@@ -23,13 +32,26 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
+        target_device: Device = Device._current_device(),
+        profiling_enabled: Optional[bool] = None,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
+
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
         self.initialized = False
+        self.target_device_id = target_device.gpu_id
+        self.target_device_properties = torch.cuda.get_device_properties(
+            self.target_device_id
+        )
+        self.profiling_enabled = (
+            profiling_enabled if profiling_enabled is not None else False
+        )
         self._initialize()
 
     def _initialize(self) -> None:
@@ -119,6 +141,9 @@ def _load_from_state_dict(
     ) -> None:
         engine_bytes = state_dict[prefix + "engine"]
 
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         logger = trt.Logger()
         runtime = trt.Runtime(logger)
         self.engine = runtime.deserialize_cuda_engine(engine_bytes)
@@ -141,15 +166,43 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         if self.engine:
             self.context = self.engine.create_execution_context()
 
-    def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         with torch.autograd.profiler.record_function(
             "PythonTorchTensorRTModule:Forward"
-        ):
+        ) if self.profiling_enabled else nullcontext():
             self._check_initialized()
 
+            # If in safe mode, check at each iteration whether a device switch is required
+            if (
+                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
+            ):
+                curr_device_id = torch.cuda.current_device()
+                curr_device_properties = torch.cuda.get_device_properties(
+                    curr_device_id
+                )
+                logger.debug(f"Current Device: cuda:{curr_device_id}")
+
+                # If a switch is required, move all inputs to the new device and set it as active
+                if _is_switch_required(
+                    curr_device_id,
+                    self.target_device_id,
+                    curr_device_properties,
+                    self.target_device_properties,
+                ):
+                    device_id, _ = _select_rt_device(
+                        curr_device_id,
+                        self.target_device_id,
+                        self.target_device_properties,
+                    )
+                    device = torch.device(device_id)
+                    torch.cuda.set_device(device_id)
+
+                    inputs = tuple([tensor.to(device) for tensor in inputs])
+                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
+
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessInputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 assert len(inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
@@ -188,7 +241,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessOutputs"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 # create output tensors
                 outputs: List[torch.Tensor] = []
 
@@ -215,7 +268,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:TensorRTRuntime"
-            ):
+            ) if self.profiling_enabled else nullcontext():
                 self.context.execute_async_v2(
                     bindings, torch.cuda.current_stream().cuda_stream
                 )
@@ -235,6 +288,8 @@ def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
         if not self.context.profiler:
             self.context.profiler = trt.Profiler() if profiler is None else profiler
 
+        self.profiling_enabled = True
+
     def disable_profiling(self) -> None:
         """
         Disable TensorRT profiling.
@@ -244,6 +299,7 @@ def disable_profiling(self) -> None:
         torch.cuda.synchronize()
         del self.context
         self.context = self.engine.create_execution_context()
+        self.profiling_enabled = False
 
     def get_layer_info(self) -> str:
         """
