@@ -18,17 +18,36 @@ auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patter
18
18
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19
19
auto self = args[0 ].ITensorOrFreeze (ctx);
20
20
auto dim = args[1 ].unwrapToInt ();
21
+ auto keep_dims = args[2 ].unwrapToBool ();
21
22
auto selfDim = util::toVec (self->getDimensions ());
22
23
if (dim < 0 ) {
23
24
dim = selfDim.size () + dim;
24
25
}
25
26
uint32_t shiftDim = 1 << dim;
26
27
auto TopKOperation = nvinfer1::TopKOperation::kMAX ;
27
- auto new_layer = ctx->net ->addTopK (*self, TopKOperation, 1 , shiftDim);
28
- TORCHTRT_CHECK (new_layer, " Unable to create max layer from node: " << *n);
28
+ auto topk_layer = ctx->net ->addTopK (*self, TopKOperation, 1 , shiftDim);
29
+ TORCHTRT_CHECK (topk_layer, " Unable to create max layer from node: " << *n);
30
+ auto topk_dims = util::toVec (topk_layer->getOutput (0 )->getDimensions ());
29
31
30
- auto out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
31
- auto out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], new_layer->getOutput (1 ));
32
+ nvinfer1::ITensor* out0;
33
+ nvinfer1::ITensor* out1;
34
+ if (!keep_dims){
35
+ if (topk_dims[dim] == 1 ) {
36
+ auto squeeze_layer = ctx->net ->addShuffle (*topk_layer->getOutput (0 ));
37
+ squeeze_layer->setReshapeDimensions (util::squeezeDims (topk_layer->getOutput (0 )->getDimensions (), dim));
38
+ TORCHTRT_CHECK (squeeze_layer, " Unable to create squeeze_layer layer from node: " << *n);
39
+ out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], squeeze_layer->getOutput (0 ));
40
+
41
+ auto squeeze_layer_indices = ctx->net ->addShuffle (*topk_layer->getOutput (1 ));
42
+ squeeze_layer_indices->setReshapeDimensions (util::squeezeDims (topk_layer->getOutput (1 )->getDimensions (), dim));
43
+ TORCHTRT_CHECK (squeeze_layer_indices, " Unable to create squeeze_layer_indices layer from node: " << *n);
44
+ out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], squeeze_layer_indices->getOutput (0 ));
45
+
46
+ }
47
+ } else {
48
+ out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], topk_layer->getOutput (0 ));
49
+ out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], topk_layer->getOutput (1 ));
50
+ }
32
51
33
52
LOG_DEBUG (" Output tensor(0) shape: " << out0->getDimensions ());
34
53
LOG_DEBUG (" Output tensor(1) shape: " << out1->getDimensions ());
0 commit comments