@@ -31,18 +31,18 @@ auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patter
31
31
32
32
nvinfer1::ITensor* out0;
33
33
nvinfer1::ITensor* out1;
34
- if (!keep_dims){
34
+ if (!keep_dims) {
35
35
if (topk_dims[dim] == 1 ) {
36
- auto squeeze_layer = ctx->net ->addShuffle (*topk_layer->getOutput (0 ));
37
- squeeze_layer->setReshapeDimensions (util::squeezeDims (topk_layer->getOutput (0 )->getDimensions (), dim));
38
- TORCHTRT_CHECK (squeeze_layer, " Unable to create squeeze_layer layer from node: " << *n);
39
- out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], squeeze_layer->getOutput (0 ));
40
-
41
- auto squeeze_layer_indices = ctx->net ->addShuffle (*topk_layer->getOutput (1 ));
42
- squeeze_layer_indices->setReshapeDimensions (util::squeezeDims (topk_layer->getOutput (1 )->getDimensions (), dim));
43
- TORCHTRT_CHECK (squeeze_layer_indices, " Unable to create squeeze_layer_indices layer from node: " << *n);
44
- out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], squeeze_layer_indices->getOutput (0 ));
36
+ auto squeeze_layer = ctx->net ->addShuffle (*topk_layer->getOutput (0 ));
37
+ squeeze_layer->setReshapeDimensions (util::squeezeDims (topk_layer->getOutput (0 )->getDimensions (), dim));
38
+ TORCHTRT_CHECK (squeeze_layer, " Unable to create squeeze_layer layer from node: " << *n);
39
+ out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], squeeze_layer->getOutput (0 ));
45
40
41
+ auto squeeze_layer_indices = ctx->net ->addShuffle (*topk_layer->getOutput (1 ));
42
+ squeeze_layer_indices->setReshapeDimensions (
43
+ util::squeezeDims (topk_layer->getOutput (1 )->getDimensions (), dim));
44
+ TORCHTRT_CHECK (squeeze_layer_indices, " Unable to create squeeze_layer_indices layer from node: " << *n);
45
+ out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], squeeze_layer_indices->getOutput (0 ));
46
46
}
47
47
} else {
48
48
out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], topk_layer->getOutput (0 ));
0 commit comments