Skip to content

Commit 4ab2856

Browse files
committed
refactor: Reorder the API since everything but the engine is optional
Also adds a new destructor to enforce cleanup order. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 71082d3 commit 4ab2856

File tree

6 files changed

+23
-8
lines changed

6 files changed

+23
-8
lines changed

core/runtime/TRTEngine.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ TRTEngine::TRTEngine(
145145
LOG_DEBUG(*this);
146146
}
147147

148+
// Destructor: explicitly resets exec_ctx, then cuda_engine, then rt, so these
// three members are released in this fixed order instead of relying on the
// implicit member-destruction order of the previously defaulted destructor.
// NOTE(review): presumably these are smart pointers to the TensorRT execution
// context, engine, and runtime, with the context required to go before the
// engine and the engine before the runtime — confirm against TRTEngine.h.
TRTEngine::~TRTEngine() {
149+
exec_ctx.reset();
150+
cuda_engine.reset();
151+
rt.reset();
152+
}
153+
148154
void TRTEngine::disable_profiling() {
149155
torch::cuda::synchronize(device_info.id);
150156
profile_execution = false;

core/runtime/TRTEngine.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct TRTEngine : torch::CustomClassHolder {
3333
std::vector<std::string> in_binding_names = {}; // ITO: PYT IDX
3434
std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX
3535

36-
~TRTEngine() = default;
36+
~TRTEngine();
3737
TRTEngine(
3838
const std::string& serialized_engine,
3939
const RTDevice& cuda_device,

py/torch_tensorrt/_TRTModuleNext.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ class TRTModuleNext(torch.nn.Module):
3131

3232
def __init__(
3333
self,
34+
serialized_engine: bytearray,
3435
name: str = "",
35-
serialized_engine: bytearray = bytearray(),
3636
input_binding_names: List[str] = [],
3737
output_binding_names: List[str] = [],
3838
target_device: Device = Device._current_device(),
@@ -42,6 +42,11 @@ def __init__(
4242
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
4343
a PyTorch ``torch.nn.Module`` around it.
4444
45+
If binding names are not provided, it is assumed that the engine binding names follow this convention:
46+
47+
- [symbol].[index in input / output array]
48+
- ex. [x.0, x.1, x.2] -> [y.0]
49+
4550
Args:
4651
name (str): Name for module
4752
serialized_engine (bytearray): Serialized TensorRT engine in the form of a bytearray
@@ -51,15 +56,15 @@ def __init__(
5156
5257
Example:
5358
54-
..code-block:: python
59+
.. code-block:: py
5560
5661
with io.BytesIO() as engine_bytes:
5762
engine_bytes.write(trt_engine.serialize())
5863
engine_str = engine_bytes.getvalue()
5964
6065
trt_module = TRTModule(
61-
engine_name="my_engine",
62-
serialized_engine=engine_str,
66+
engine_str,
67+
name="my_module",
6368
input_binding_names=["x"],
6469
output_binding_names=["output"],
6570
)
@@ -69,6 +74,10 @@ def __init__(
6974
"TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
7075
)
7176
super(TRTModuleNext, self).__init__()
77+
78+
if not isinstance(serialized_engine, bytearray):
79+
ValueError("Expected serialized engine as bytearray")
80+
7281
self.input_binding_names = input_binding_names
7382
self.output_binding_names = output_binding_names
7483
self.name = name

py/torch_tensorrt/fx/lower.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,8 @@ def lower_pass(
185185
engine_str = engine_bytes.getvalue()
186186

187187
trt_module = TRTModuleNext(
188+
engine_str,
188189
name=module_name,
189-
serialized_engine=engine_str,
190190
input_binding_names=interp_res.input_names,
191191
output_binding_names=interp_res.output_names,
192192
target_device=Device(f"cuda:{torch.cuda.current_device()}"),

py/torch_tensorrt/fx/tools/trt_minimizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def lower_mod_default(
3030
engine_str = engine_bytes.getvalue()
3131

3232
res_mod = TRTModuleNext(
33+
engine_str,
3334
name=str(type(mod)),
34-
serialized_engine=engine_str,
3535
input_binding_names=interpreter_result.input_names,
3636
output_binding_names=interpreter_result.output_names,
3737
target_device=Device(f"cuda:{torch.cuda.current_device()}"),

py/torch_tensorrt/fx/tools/trt_splitter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def _lower_model_to_backend(
100100
engine_str = engine_bytes.getvalue()
101101

102102
return TRTModuleNext(
103+
engine_str,
103104
name=str(type(mod)),
104-
serialized_engine=engine_str,
105105
input_binding_names=interpreter_result.input_names,
106106
output_binding_names=interpreter_result.output_names,
107107
target_device=Device(f"cuda:{torch.cuda.current_device()}"),

0 commit comments

Comments
 (0)