Commit c8dc6e9

feat: support aten::conv1d and aten::conv_transpose1d
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 4d95b04 commit c8dc6e9
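This commit adds converter support for the two 1D convolution ops named in the message. For orientation, here is a minimal standalone libtorch sketch (not part of the commit) exercising both ATen signatures; note that aten::conv_transpose1d orders its trailing arguments as output_padding, groups, dilation, unlike aten::conv1d's dilation, groups:

// Standalone libtorch sketch (not part of this commit) illustrating the two
// ATen ops the new converters handle.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({1, 3, 16}); // (N, C_in, L)
  auto w = torch::randn({4, 3, 3});  // (C_out, C_in, kernel)
  auto b = torch::randn({4});

  // aten::conv1d(input, weight, bias, stride, padding, dilation, groups)
  auto y = torch::conv1d(x, w, b, /*stride=*/1, /*padding=*/0, /*dilation=*/1, /*groups=*/1);
  std::cout << y.sizes() << std::endl; // [1, 4, 14]

  // aten::conv_transpose1d(input, weight, bias, stride, padding,
  //                        output_padding, groups, dilation)
  // Note the transposed weight layout (C_in, C_out / groups, kernel) and that
  // output_padding/groups come before dilation.
  auto wt = torch::randn({3, 4, 3});
  auto z = torch::conv_transpose1d(
      x, wt, b, /*stride=*/1, /*padding=*/0, /*output_padding=*/0, /*groups=*/1, /*dilation=*/1);
  std::cout << z.sizes() << std::endl; // [1, 4, 18]
}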

2 files changed: +148, -23 lines

Diff for: core/conversion/converters/impl/conv_deconv.cpp (+62, -23)
@@ -10,18 +10,19 @@ namespace converters {
 namespace impl {
 namespace {
 
-bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
+bool add_conv_deconv(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    args& args,
+    nvinfer1::Dims& stride,
+    nvinfer1::Dims& padding,
+    nvinfer1::Dims& dilation,
+    bool transposed,
+    nvinfer1::Dims& out_padding,
+    int64_t groups) {
   // Input to conv/deconv
   auto in = args[0].ITensor();
 
-  // Conv /deconv parameters
-  auto stride = util::toDims(args[3].unwrapToIntList());
-  auto padding = util::toDims(args[4].unwrapToIntList());
-  auto dilation = util::toDims(args[5].unwrapToIntList());
-  bool transposed = args[6].unwrapToBool();
-  auto out_padding = util::toDims(args[7].unwrapToIntList());
-  int64_t groups = args[8].unwrapToInt();
-
   // Reshape the parameters to 2D if needed
   if (stride.nbDims == 1) {
     stride = util::unsqueezeDims(stride, 1, 1);
@@ -174,28 +175,66 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
   return true;
 }
 
-auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
-    .pattern({
-        R"SIG(aten::_convolution(Tensor input, Tensor weight,
+auto conv_registrations TRTORCH_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern({
+            R"SIG(aten::_convolution(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor))SIG",
-        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-          return add_conv_deconv(ctx, n, args);
-        }})
-    .pattern({
-        R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              // Conv /deconv parameters
+              auto stride = util::toDims(args[3].unwrapToIntList());
+              auto padding = util::toDims(args[4].unwrapToIntList());
+              auto dilation = util::toDims(args[5].unwrapToIntList());
+              bool transposed = args[6].unwrapToBool();
+              auto out_padding = util::toDims(args[7].unwrapToIntList());
+              int64_t groups = args[8].unwrapToInt();
+              return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
+            }})
+        .pattern({
+            R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled) -> (Tensor))SIG",
-        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-          // This pattern is only matched for traced JIT models which do not
-          // have allow_tf32 bool in the function signature. The TRT conversion
-          // code is exactly same as the above call.
-          return add_conv_deconv(ctx, n, args);
-        }});
+            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+              // This pattern is only matched for traced JIT models which do not
+              // have allow_tf32 bool in the function signature. The TRT conversion
+              // code is exactly same as the above call.
+              auto stride = util::toDims(args[3].unwrapToIntList());
+              auto padding = util::toDims(args[4].unwrapToIntList());
+              auto dilation = util::toDims(args[5].unwrapToIntList());
+              bool transposed = args[6].unwrapToBool();
+              auto out_padding = util::toDims(args[7].unwrapToIntList());
+              int64_t groups = args[8].unwrapToInt();
+              return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
+            }})
+        .pattern(
+            {R"SIG(aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor)SIG",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // Conv /deconv parameters
+               auto stride = util::toDims(args[3].unwrapToIntList());
+               auto padding = util::toDims(args[4].unwrapToIntList());
+               auto dilation = util::toDims(args[5].unwrapToIntList());
+               bool transposed = false;
+               nvinfer1::Dims out_padding{1, {0}};
+               int64_t groups = args[6].unwrapToInt();
+               return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
+             }})
+        .pattern(
+            {R"SIG(aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor)SIG",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // Conv /deconv parameters
+               auto stride = util::toDims(args[3].unwrapToIntList());
+               auto padding = util::toDims(args[4].unwrapToIntList());
+               auto out_padding = util::toDims(args[5].unwrapToIntList());
+               bool transposed = true;
+               int64_t groups = args[6].unwrapToInt();
+               auto dilation = util::toDims(args[7].unwrapToIntList());
+               return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
+             }});
 } // namespace
 } // namespace impl
 } // namespace converters
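Worth noting: the refactor leaves the 1D handling inside add_conv_deconv unchanged. When stride/padding/dilation arrive with nbDims == 1, they are unsqueezed to 2D and the op is executed as a 2D convolution with a unit-sized trailing spatial dimension. Below is a standalone libtorch sketch (not part of this commit) of the identity that lifting relies on:

// Standalone libtorch check (not part of this commit) of the 1D -> 2D lifting
// that add_conv_deconv performs when stride/padding/dilation have nbDims == 1.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({1, 3, 16});
  auto w = torch::randn({4, 3, 3});
  auto b = torch::randn({4});

  auto y1 = torch::conv1d(x, w, b, /*stride=*/2, /*padding=*/1);

  // Mirror util::unsqueezeDims(..., 1, 1): append a trailing unit dimension to
  // the input, the weight, and each parameter list, run conv2d, then drop it.
  auto y2 = torch::conv2d(
                x.unsqueeze(3), w.unsqueeze(3), b,
                /*stride=*/{2, 1}, /*padding=*/{1, 0}, /*dilation=*/{1, 1}, /*groups=*/1)
                .squeeze(3);

  std::cout << torch::allclose(y1, y2) << std::endl; // expected: 1
}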

Diff for: tests/core/conversion/converters/test_conv_deconv.cpp (+86)
@@ -10,6 +10,12 @@
 // int[] output_padding, int groups, bool benchmark,
 // bool deterministic, bool cudnn_enabled) -> (Tensor)
 
+// aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) ->
+// Tensor
+
+// aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding,
+// int groups, int[] dilation) -> Tensor
+
 void conv_test_helper(std::string graph_ir) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph_ir, g.get());
@@ -116,6 +122,86 @@ TEST(Converters, ATenConvolution1dConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenConv1dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(3)):
+        %3 : int = prim::Constant[value=1]()
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %8 : int[] = prim::ListConstruct(%3)
+        %9 : int[] = prim::ListConstruct(%4)
+        %10 : int[] = prim::ListConstruct(%5)
+        %12 : Tensor = aten::conv1d(%0, %1, %2, %8, %9, %10, %3)
+        return (%12))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto jit_b = at::clone(b);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenConvTranspose1dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+            %2 : Float(3)):
+        %3 : int = prim::Constant[value=1]()
+        %4 : int = prim::Constant[value=0]()
+        %5 : int = prim::Constant[value=1]()
+        %6 : int = prim::Constant[value=0]()
+        %8 : int[] = prim::ListConstruct(%3)
+        %9 : int[] = prim::ListConstruct(%4)
+        %10 : int[] = prim::ListConstruct(%5)
+        %11 : int[] = prim::ListConstruct(%6)
+        %12 : Tensor = aten::conv_transpose1d(%0, %1, %2, %8, %9, %11, %3, %10)
+        return (%12))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 2, {1, 8, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {8, 4, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto jit_w = at::clone(w);
+  auto jit_b = at::clone(b);
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor,
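As a sanity check on the new transposed test's shapes (standalone libtorch, outside the TRTorch harness): for aten::conv_transpose1d the weight layout is (C_in, C_out / groups, kernel), so the {8, 4, 3} weight maps 8 input channels to 4 output channels (the Float(4, 3, 3) annotation on %1 in the IR header is not what the harness actually feeds in), and with stride 1, padding 0, output_padding 0, dilation 1 the output length is (3 - 1) * 1 - 2 * 0 + 1 * (3 - 1) + 0 + 1 = 5, giving a {1, 4, 5} result from both the JIT and TRT paths:

// Standalone libtorch shape check (not part of this commit) for the inputs
// used by ATenConvTranspose1dConvertsCorrectly above.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto in = torch::randint(1, 2, {1, 8, 3}).to(torch::kFloat);
  auto w = torch::randint(1, 2, {8, 4, 3}).to(torch::kFloat);
  auto b = torch::randint(1, 10, {4}).to(torch::kFloat);

  auto out = torch::conv_transpose1d(
      in, w, b, /*stride=*/1, /*padding=*/0, /*output_padding=*/0, /*groups=*/1, /*dilation=*/1);
  std::cout << out.sizes() << std::endl; // [1, 4, 5]
}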
