@@ -10,18 +10,19 @@ namespace converters {
10
10
namespace impl {
11
11
namespace {
12
12
13
- bool add_conv_deconv (ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
13
+ bool add_conv_deconv (
14
+ ConversionCtx* ctx,
15
+ const torch::jit::Node* n,
16
+ args& args,
17
+ nvinfer1::Dims& stride,
18
+ nvinfer1::Dims& padding,
19
+ nvinfer1::Dims& dilation,
20
+ bool transposed,
21
+ nvinfer1::Dims& out_padding,
22
+ int64_t groups) {
14
23
// Input to conv/deconv
15
24
auto in = args[0 ].ITensor ();
16
25
17
- // Conv /deconv parameters
18
- auto stride = util::toDims (args[3 ].unwrapToIntList ());
19
- auto padding = util::toDims (args[4 ].unwrapToIntList ());
20
- auto dilation = util::toDims (args[5 ].unwrapToIntList ());
21
- bool transposed = args[6 ].unwrapToBool ();
22
- auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
23
- int64_t groups = args[8 ].unwrapToInt ();
24
-
25
26
// Reshape the parameters to 2D if needed
26
27
if (stride.nbDims == 1 ) {
27
28
stride = util::unsqueezeDims (stride, 1 , 1 );
@@ -174,28 +175,66 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
174
175
return true ;
175
176
}
176
177
177
- auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
178
- .pattern({
179
- R"SIG( aten::_convolution(Tensor input, Tensor weight,
178
+ auto conv_registrations TRTORCH_UNUSED =
179
+ RegisterNodeConversionPatterns ()
180
+ .pattern({
181
+ R"SIG( aten::_convolution(Tensor input, Tensor weight,
180
182
Tensor? bias, int[] stride, int[] padding,
181
183
int[] dilation, bool transposed,
182
184
int[] output_padding, int groups, bool benchmark,
183
185
bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor))SIG" ,
184
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
185
- return add_conv_deconv (ctx, n, args);
186
- }})
187
- .pattern({
188
- R"SIG( aten::_convolution.deprecated(Tensor input, Tensor weight,
186
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
187
+ // Conv /deconv parameters
188
+ auto stride = util::toDims (args[3 ].unwrapToIntList ());
189
+ auto padding = util::toDims (args[4 ].unwrapToIntList ());
190
+ auto dilation = util::toDims (args[5 ].unwrapToIntList ());
191
+ bool transposed = args[6 ].unwrapToBool ();
192
+ auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
193
+ int64_t groups = args[8 ].unwrapToInt ();
194
+ return add_conv_deconv (ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
195
+ }})
196
+ .pattern({
197
+ R"SIG( aten::_convolution.deprecated(Tensor input, Tensor weight,
189
198
Tensor? bias, int[] stride, int[] padding,
190
199
int[] dilation, bool transposed,
191
200
int[] output_padding, int groups, bool benchmark,
192
201
bool deterministic, bool cudnn_enabled) -> (Tensor))SIG" ,
193
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
194
- // This pattern is only matched for traced JIT models which do not
195
- // have allow_tf32 bool in the function signature. The TRT conversion
196
- // code is exactly same as the above call.
197
- return add_conv_deconv (ctx, n, args);
198
- }});
202
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
203
+ // This pattern is only matched for traced JIT models which do not
204
+ // have allow_tf32 bool in the function signature. The TRT conversion
205
+ // code is exactly same as the above call.
206
+ auto stride = util::toDims (args[3 ].unwrapToIntList ());
207
+ auto padding = util::toDims (args[4 ].unwrapToIntList ());
208
+ auto dilation = util::toDims (args[5 ].unwrapToIntList ());
209
+ bool transposed = args[6 ].unwrapToBool ();
210
+ auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
211
+ int64_t groups = args[8 ].unwrapToInt ();
212
+ return add_conv_deconv (ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
213
+ }})
214
+ .pattern(
215
+ {R"SIG( aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor)SIG" ,
216
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
217
+ // Conv /deconv parameters
218
+ auto stride = util::toDims (args[3 ].unwrapToIntList ());
219
+ auto padding = util::toDims (args[4 ].unwrapToIntList ());
220
+ auto dilation = util::toDims (args[5 ].unwrapToIntList ());
221
+ bool transposed = false ;
222
+ nvinfer1::Dims out_padding{1 , {0 }};
223
+ int64_t groups = args[6 ].unwrapToInt ();
224
+ return add_conv_deconv (ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
225
+ }})
226
+ .pattern(
227
+ {R"SIG( aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor)SIG" ,
228
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
229
+ // Conv /deconv parameters
230
+ auto stride = util::toDims (args[3 ].unwrapToIntList ());
231
+ auto padding = util::toDims (args[4 ].unwrapToIntList ());
232
+ auto out_padding = util::toDims (args[5 ].unwrapToIntList ());
233
+ bool transposed = true ;
234
+ int64_t groups = args[6 ].unwrapToInt ();
235
+ auto dilation = util::toDims (args[7 ].unwrapToIntList ());
236
+ return add_conv_deconv (ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
237
+ }});
199
238
} // namespace
200
239
} // namespace impl
201
240
} // namespace converters
0 commit comments