Skip to content

Commit 84f626a

Browse files
committed
fix: Fix keep_dims functionality for aten::max
Signed-off-by: dperi <[email protected]>
1 parent 91a92ca commit 84f626a

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

core/conversion/converters/impl/max.cpp

+23-4
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,36 @@ auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patter
1818
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
1919
auto self = args[0].ITensorOrFreeze(ctx);
2020
auto dim = args[1].unwrapToInt();
21+
auto keep_dims = args[2].unwrapToBool();
2122
auto selfDim = util::toVec(self->getDimensions());
2223
if (dim < 0) {
2324
dim = selfDim.size() + dim;
2425
}
2526
uint32_t shiftDim = 1 << dim;
2627
auto TopKOperation = nvinfer1::TopKOperation::kMAX;
27-
auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
28-
TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n);
28+
auto topk_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
29+
TORCHTRT_CHECK(topk_layer, "Unable to create max layer from node: " << *n);
30+
auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions());
2931

30-
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
31-
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));
32+
nvinfer1::ITensor* out0;
33+
nvinfer1::ITensor* out1;
34+
if (!keep_dims){
35+
if (topk_dims[dim] == 1) {
36+
auto squeeze_layer = ctx->net->addShuffle(*topk_layer->getOutput(0));
37+
squeeze_layer->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(0)->getDimensions(), dim));
38+
TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n);
39+
out0 = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer->getOutput(0));
40+
41+
auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1));
42+
squeeze_layer_indices->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim));
43+
TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n);
44+
out1 = ctx->AssociateValueAndTensor(n->outputs()[1], squeeze_layer_indices->getOutput(0));
45+
46+
}
47+
} else {
48+
out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0));
49+
out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1));
50+
}
3251

3352
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
3453
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());

tests/core/conversion/converters/test_reduce.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,18 @@ TEST(Converters, ATenProdKeepDimsConvertsCorrectly) {
212212
test_body(graph, in);
213213
}
214214

215+
TEST(Converters, ATenMaxKeepDimsConvertsCorrectly) {
216+
const auto graph = R"IR(
217+
graph(%x : Tensor):
218+
%2 : int = prim::Constant[value=-1]()
219+
%3 : bool = prim::Constant[value=1]()
220+
%keep.1 : Tensor, %6 : Tensor = aten::max(%x, %2, %3)
221+
return (%keep.1, %6))IR";
222+
223+
auto in = at::randint(-5, 5, {4, 4}, at::kCUDA);
224+
test_body(graph, in);
225+
}
226+
215227
TEST(Converters, ATenMeanDimNegOneIndexConvertsCorrectly) {
216228
const auto graph = R"IR(
217229
graph(%0 : Tensor):

0 commit comments

Comments
 (0)