You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: py/torch_tensorrt/_TRTModuleNext.py
+13-4
Original file line number
Diff line number
Diff line change
@@ -31,8 +31,8 @@ class TRTModuleNext(torch.nn.Module):
31
31
32
32
def__init__(
33
33
self,
34
+
serialized_engine: bytearray,
34
35
name: str="",
35
-
serialized_engine: bytearray=bytearray(),
36
36
input_binding_names: List[str] = [],
37
37
output_binding_names: List[str] = [],
38
38
target_device: Device=Device._current_device(),
@@ -42,6 +42,11 @@ def __init__(
42
42
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
43
43
a PyTorch ``torch.nn.Module`` around it.
44
44
45
+
If binding names are not provided, it is assumed that the engine binding names follow the following convention:
46
+
47
+
- [symbol].[index in input / output array]
48
+
- ex. [x.0, x.1, x.2] -> [y.0]
49
+
45
50
Args:
46
51
name (str): Name for module
47
52
serialized_engine (bytearray): Serialized TensorRT engine in the form of a bytearray
@@ -51,15 +56,15 @@ def __init__(
51
56
52
57
Example:
53
58
54
-
..code-block:: python
59
+
..code-block:: py
55
60
56
61
with io.BytesIO() as engine_bytes:
57
62
engine_bytes.write(trt_engine.serialize())
58
63
engine_str = engine_bytes.getvalue()
59
64
60
65
trt_module = TRTModule(
61
-
engine_name="my_engine",
62
-
serialized_engine=engine_str,
66
+
engine_str,
67
+
engine_name="my_module",
63
68
input_names=["x"],
64
69
output_names=["output"],
65
70
)
@@ -69,6 +74,10 @@ def __init__(
69
74
"TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
70
75
)
71
76
super(TRTModuleNext, self).__init__()
77
+
78
+
ifnotisinstance(serialized_engine, bytearray):
79
+
ValueError("Expected serialized engine as bytearray")
0 commit comments