@@ -14,35 +14,57 @@ namespace converters {
14
14
namespace impl {
15
15
namespace {
16
16
17
- auto squeeze_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
18
- {" aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))" ,
19
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
20
- auto self = args[0 ].ITensorOrFreeze (ctx);
21
- auto dim = args[1 ].unwrapToInt ();
17
+ auto squeeze_registrations TORCHTRT_UNUSED =
18
+ RegisterNodeConversionPatterns ()
19
+ .pattern(
20
+ {" aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))" ,
21
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
22
+ auto self = args[0 ].ITensorOrFreeze (ctx);
23
+ auto dim = args[1 ].unwrapToInt ();
22
24
23
- auto selfDim = util::toVec (self->getDimensions ());
24
- if (dim < 0 ) {
25
- dim = selfDim.size () + dim;
26
- }
25
+ auto selfDim = util::toVec (self->getDimensions ());
26
+ if (dim < 0 ) {
27
+ dim = selfDim.size () + dim;
28
+ }
27
29
28
- if (selfDim[dim] != 1 ) {
29
- auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], self);
30
+ if (selfDim[dim] != 1 ) {
31
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], self);
30
32
31
- LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
33
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
32
34
33
- return true ;
34
- }
35
+ return true ;
36
+ }
35
37
36
- auto shuffle_layer = ctx->net ->addShuffle (*self);
37
- TORCHTRT_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
38
- shuffle_layer->setReshapeDimensions (util::squeezeDims (self->getDimensions (), dim));
38
+ auto shuffle_layer = ctx->net ->addShuffle (*self);
39
+ TORCHTRT_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
40
+ shuffle_layer->setReshapeDimensions (util::squeezeDims (self->getDimensions (), dim));
39
41
40
- auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_layer->getOutput (0 ));
42
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], shuffle_layer->getOutput (0 ));
41
43
42
- LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
44
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
43
45
44
- return true ;
45
- }});
46
+ return true ;
47
+ }})
48
+ .pattern(
49
+ {" aten::squeeze(Tensor(a) self) -> (Tensor(a))" ,
50
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
51
+ auto self = args[0 ].ITensorOrFreeze (ctx);
52
+ auto self_dims = self->getDimensions ();
53
+ auto out = self;
54
+ auto squeeze_dims = util::squeezeAllDims (self_dims);
55
+ if (squeeze_dims != self_dims) {
56
+ auto shuffle_layer = ctx->net ->addShuffle (*self);
57
+ TORCHTRT_CHECK (shuffle_layer, " Unable to create shuffle layer from node: " << *n);
58
+ shuffle_layer->setReshapeDimensions (squeeze_dims);
59
+ out = shuffle_layer->getOutput (0 );
60
+ }
61
+
62
+ auto trt_out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out);
63
+
64
+ LOG_DEBUG (" Output tensor shape: " << trt_out->getDimensions ());
65
+
66
+ return true ;
67
+ }});
46
68
47
69
} // namespace
48
70
} // namespace impl
0 commit comments