1
1
#include " core/conversion/converters/converters.h"
2
2
3
+ #include " torch/torch.h"
4
+
3
5
namespace trtorch {
4
6
namespace core {
5
7
namespace conversion {
@@ -8,23 +10,41 @@ namespace impl {
8
10
namespace {
9
11
10
12
static auto shuffle_registrations = RegisterNodeConversionPatterns()
11
- .pattern({
12
- " aten::reshape(Tensor self, int[] shape) -> (Tensor)" ,
13
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
14
- auto in = args[0 ].ITensor ();
15
- auto new_shape = util::toDimsPad (args[1 ].unwrapToIntList (), 2 );
16
-
17
- auto shuffle = ctx->net ->addShuffle (*in);
18
- TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
19
- shuffle->setReshapeDimensions (new_shape);
20
- shuffle->setName (util::node_info (n).c_str ());
21
-
22
- auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle->getOutput (0 ));
23
- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
24
-
25
- return true ;
26
- }
27
- });
13
+ .pattern({
14
+ " aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)" ,
15
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
16
+ auto in = args[0 ].ITensor ();
17
+ auto start_dim = args[1 ].unwrapToInt ();
18
+ auto end_dim = args[2 ].unwrapToInt ();
19
+ auto in_shape = util::toVec (in->getDimensions ());
20
+ auto out_shape = torch::flatten (torch::rand (in_shape), start_dim, end_dim).sizes ();
21
+
22
+ auto shuffle = ctx->net ->addShuffle (*in);
23
+ TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
24
+ shuffle->setReshapeDimensions (util::toDims (out_shape));
25
+ shuffle->setName (util::node_info (n).c_str ());
26
+
27
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle->getOutput (0 ));
28
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
29
+ return true ;
30
+ }
31
+ }).pattern({
32
+ " aten::reshape(Tensor self, int[] shape) -> (Tensor)" ,
33
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
34
+ auto in = args[0 ].ITensor ();
35
+ auto new_shape = util::toDimsPad (args[1 ].unwrapToIntList (), 2 );
36
+
37
+ auto shuffle = ctx->net ->addShuffle (*in);
38
+ TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
39
+ shuffle->setReshapeDimensions (new_shape);
40
+ shuffle->setName (util::node_info (n).c_str ());
41
+
42
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle->getOutput (0 ));
43
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
44
+
45
+ return true ;
46
+ }
47
+ });
28
48
} // namespace
29
49
} // namespace impl
30
50
} // namespace converters
0 commit comments