Skip to content

Commit d945eb9

Browse files
committed
feat(aten::flatten): Adds a converter for aten flatten since MM is the
preferred path now Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 4acc3fd commit d945eb9

File tree

2 files changed

+85
-17
lines changed

2 files changed

+85
-17
lines changed

Diff for: core/conversion/converters/impl/shuffle.cpp

+37-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "core/conversion/converters/converters.h"
22

3+
#include "torch/torch.h"
4+
35
namespace trtorch {
46
namespace core {
57
namespace conversion {
@@ -8,23 +10,41 @@ namespace impl {
810
namespace {
911

1012
static auto shuffle_registrations = RegisterNodeConversionPatterns()
11-
.pattern({
12-
"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
13-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
14-
auto in = args[0].ITensor();
15-
auto new_shape = util::toDimsPad(args[1].unwrapToIntList(), 2);
16-
17-
auto shuffle = ctx->net->addShuffle(*in);
18-
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
19-
shuffle->setReshapeDimensions(new_shape);
20-
shuffle->setName(util::node_info(n).c_str());
21-
22-
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
23-
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
24-
25-
return true;
26-
}
27-
});
13+
.pattern({
14+
"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
15+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
16+
auto in = args[0].ITensor();
17+
auto start_dim = args[1].unwrapToInt();
18+
auto end_dim = args[2].unwrapToInt();
19+
auto in_shape = util::toVec(in->getDimensions());
20+
auto out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes();
21+
22+
auto shuffle = ctx->net->addShuffle(*in);
23+
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
24+
shuffle->setReshapeDimensions(util::toDims(out_shape));
25+
shuffle->setName(util::node_info(n).c_str());
26+
27+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
28+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
29+
return true;
30+
}
31+
}).pattern({
32+
"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
33+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
34+
auto in = args[0].ITensor();
35+
auto new_shape = util::toDimsPad(args[1].unwrapToIntList(), 2);
36+
37+
auto shuffle = ctx->net->addShuffle(*in);
38+
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
39+
shuffle->setReshapeDimensions(new_shape);
40+
shuffle->setName(util::node_info(n).c_str());
41+
42+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
43+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
44+
45+
return true;
46+
}
47+
});
2848
} // namespace
2949
} // namespace impl
3050
} // namespace converters

Diff for: tests/core/converters/test_shuffle.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,54 @@
44
#include "tests/util/util.h"
55
#include "core/compiler.h"
66

7+
// TODO: IR Parser doesnt work well with neg numbers
8+
TEST(Converters, ATenFlattenConvertsCorrectly) {
9+
const auto graph = R"IR(
10+
graph(%0 : Tensor):
11+
%1 : int = prim::Constant[value=0]()
12+
%2 : int = prim::Constant[value=1]()
13+
%3 : Tensor = aten::flatten(%0, %1, %2)
14+
return (%3))IR";
15+
16+
auto g = std::make_shared<torch::jit::Graph>();
17+
torch::jit::parseIR(graph, &*g);
18+
19+
auto in = at::randint(0, 5, {2, 3}, {at::kCUDA});
20+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
21+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
22+
23+
in = at::clone(in);
24+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
25+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
26+
auto trt = trt_results[0].reshape_as(jit_results[0]);
27+
28+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
29+
}
30+
31+
// TODO: IR Parser doesnt work well with neg numbers
32+
TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
33+
const auto graph = R"IR(
34+
graph(%0 : Tensor):
35+
%1 : int = prim::Constant[value=1]()
36+
%2 : int = prim::Constant[value=2]()
37+
%3 : Tensor = aten::flatten(%0, %1, %2)
38+
return (%3))IR";
39+
40+
auto g = std::make_shared<torch::jit::Graph>();
41+
torch::jit::parseIR(graph, &*g);
42+
43+
auto in = at::randint(0, 5, {2, 3, 3}, {at::kCUDA});
44+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
45+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
46+
47+
in = at::clone(in);
48+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
49+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
50+
auto trt = trt_results[0].reshape_as(jit_results[0]);
51+
52+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
53+
}
54+
755
TEST(Converters, ATenReshapeConvertsCorrectly) {
856
const auto graph = R"IR(
957
graph(%0 : Tensor):

0 commit comments

Comments
 (0)