@@ -33,11 +33,8 @@ def generate_plugin(plugin_name: str):
33
33
# helper function that generates the required signature based on the torch operation
34
34
def generate_signature (torch_op ):
35
35
schema = torch_op ._schemas ["" ]
36
- tensor_args = []
37
- arg_list = []
38
36
39
- args = []
40
- kwargs = []
37
+ arg_list = []
41
38
42
39
register_func_annotation = {}
43
40
impl_func_annotation = {}
@@ -56,7 +53,6 @@ def generate_signature(torch_op):
56
53
# - torch._C.ClassType
57
54
58
55
if arg .type .isSubtypeOf (torch ._C .TensorType .get ()):
59
- tensor_args .append (arg )
60
56
register_func_annotation [arg .name ] = trtp .TensorDesc
61
57
impl_func_annotation [arg .name ] = trtp .Tensor
62
58
elif arg .type .isSubtypeOf (torch ._C .FloatType .get ()):
@@ -74,40 +70,32 @@ def generate_signature(torch_op):
74
70
else :
75
71
raise ValueError ("arg type is not handled" )
76
72
77
- if arg .default_value is None :
78
- args .append (arg .name )
79
- else :
80
- kwargs .append (f"{ arg .name } = { arg .default_value } " )
81
-
82
73
input_signature = ", " .join (arg_list )
74
+
83
75
plugin_signature = f"def add_plugin_desc({ input_signature } ):"
84
- args_input = ", " .join (args )
85
- kwargs_input = ", " .join (kwargs )
86
76
87
77
plugin_impl_arg_list = arg_list
88
78
plugin_impl_arg_list .append ("outputs" )
89
79
plugin_impl_arg_list .append ("stream" )
90
80
plugin_impl_input = ", " .join (plugin_impl_arg_list )
91
- plugin_impl_signagture = f"def add_plugin_impl({ plugin_impl_input } ):"
81
+ plugin_impl_signature = f"def add_plugin_impl({ plugin_impl_input } ):"
92
82
93
83
register_func_annotation ["return" ] = Tuple [trtp .TensorDesc ]
94
84
95
85
impl_func_annotation ["outputs" ] = Tuple [trtp .Tensor ]
96
86
impl_func_annotation ["stream" ] = int
97
87
98
88
return (
99
- args_input ,
100
- kwargs_input ,
89
+ input_signature ,
101
90
plugin_signature ,
102
- plugin_impl_signagture ,
91
+ plugin_impl_signature ,
103
92
register_func_annotation ,
104
93
impl_func_annotation ,
105
94
)
106
95
107
96
# Use the helper function to get the required signatures
108
97
(
109
- args_input ,
110
- kwargs_input ,
98
+ input_signature ,
111
99
plugin_signature ,
112
100
plugin_impl_signature ,
113
101
register_func_annotation ,
@@ -118,8 +106,11 @@ def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
118
106
shape_env = ShapeEnv ()
119
107
fake_mode = FakeTensorMode (shape_env = shape_env )
120
108
syms_args = []
121
- for arg in args :
122
- sample = {f"{ i } " : 5 for i in range (arg .ndim )}
109
+ tensor_args = [elem for elem in args if isinstance (elem , trtp .TensorDesc )]
110
+
111
+ for tensor_arg in tensor_args :
112
+
113
+ sample = {f"{ i } " : 5 for i in range (tensor_arg .ndim )}
123
114
syms_arg = [
124
115
mksym (shape_env , v , LocalSource (k ), DimDynamic .DYNAMIC )
125
116
for k , v in sample .items ()
@@ -142,16 +133,16 @@ def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
142
133
tuple (input_node_expr ), output .shape [i ].node .expr , "math"
143
134
)
144
135
145
- out_desc = args [0 ].like ()
136
+ out_desc = tensor_args [0 ].like ()
146
137
for i in range (out_desc .ndim ):
147
- input_shape_expr = [arg .shape_expr [i ] for arg in args ]
138
+ input_shape_expr = [tensor_arg .shape_expr [i ] for tensor_arg in tensor_args ]
148
139
out_desc .shape_expr [i ] = shape_calc_fns [i ](* input_shape_expr )
149
140
150
141
return (out_desc ,)
151
142
152
143
codegen_plugin = f"""
153
144
{ plugin_signature }
154
- return _generic_plugin_desc({ args_input } , { kwargs_input } )
145
+ return _generic_plugin_desc({ input_signature } )
155
146
"""
156
147
157
148
_LOGGER .warning (f"Plugin registration function: \n { codegen_plugin } " )
@@ -160,26 +151,35 @@ def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
160
151
161
152
globals ()["_generic_plugin_desc" ] = _generic_plugin_desc
162
153
163
- plugin = FunctionType (plugin_code .co_consts [0 ], globals (), "plugin" )
154
+ plugin = FunctionType (
155
+ plugin_code .co_consts [0 ],
156
+ globals (),
157
+ "plugin" ,
158
+ )
164
159
165
160
# Function annotation is required for dynamic function to work in TensorRT.Plugin
166
161
plugin .__annotations__ = register_func_annotation
167
162
168
163
trtp .register (plugin_name )(plugin )
169
164
170
165
def _generic_plugin_impl (outputs , stream , * args , ** kwargs ):
171
- in_tensors = [torch .as_tensor (i , device = "cuda" ) for i in args ]
166
+ tensor_args = [elem for elem in args if isinstance (elem , trtp .Tensor )]
167
+ print (args )
168
+ non_tensor_args = [elem for elem in args if not isinstance (elem , trtp .Tensor )]
169
+ in_tensors = [torch .as_tensor (i , device = "cuda" ) for i in tensor_args ]
172
170
173
171
dest_tensors = [torch .as_tensor (o , device = "cuda" ) for o in outputs ]
174
172
175
173
stream = torch .cuda .ExternalStream (stream )
176
174
with torch .cuda .stream (stream ):
177
- out_tensors = torch_op (* in_tensors , ** kwargs )
175
+ out_tensors = torch_op (* in_tensors , * non_tensor_args , ** kwargs )
176
+ if isinstance (out_tensors , torch .Tensor ):
177
+ out_tensors = (out_tensors ,)
178
178
[d .copy_ (o ) for (d , o ) in zip (dest_tensors , out_tensors )]
179
179
180
180
plugin_impl_func = f"""
181
181
{ plugin_impl_signature }
182
- _generic_plugin_impl(outputs, stream, { args_input } , { kwargs_input } )
182
+ _generic_plugin_impl(outputs, stream, { input_signature } )
183
183
"""
184
184
185
185
_LOGGER .warning (f"Plugin implementation function: \n { plugin_impl_func } " )
@@ -193,5 +193,3 @@ def _generic_plugin_impl(outputs, stream, *args, **kwargs):
193
193
plugin_impl .__annotations__ = impl_func_annotation
194
194
195
195
trtp .impl (plugin_name )(plugin_impl )
196
-
197
- return plugin
0 commit comments