Skip to content

Commit 7622a97

Browse files
committed
fix(//core/conversion/converters/impl/reduce): Adds support for multiple
reduction dimensions Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 377ad67 commit 7622a97

File tree

2 files changed

+44
-7
lines changed

Diff for: core/conversion/converters/impl/reduce.cpp

+18-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <bitset>
12
#include "core/util/prelude.h"
23
#include "core/conversion/converters/converters.h"
34

@@ -22,25 +23,36 @@ auto reduced_registrations = RegisterNodeConversionPatterns()
2223
TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);
2324

2425
mean_layer->setName(util::node_info(n).c_str());
25-
ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
26+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
27+
28+
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
2629
return true;
2730
}
2831
}).pattern({
29-
"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
32+
"aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)",
3033
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
3134
auto in_tensor = args[0].ITensor();
32-
auto dim = args[1].unwrapToIntList();
33-
auto keepdim = args[2].unwrapToBool();
35+
auto dims = args[1].unwrapToIntList();
36+
LOG_DEBUG("Dim to reduce:" << util::toDims(dims)); // Some abuse of toDim but just for debug info
3437

35-
uint32_t axis_mask = 1 << dim[0];
38+
uint32_t axis_mask = 0;
39+
for (int d = 0; d < dims.size(); d++) {
40+
axis_mask |= 1 << dims[d];
41+
}
42+
LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
43+
44+
auto keepdim = args[2].unwrapToBool();
45+
LOG_DEBUG("Keep dims :" << keepdim);
3646

3747
LOG_WARNING("Mean converter disregards dtype");
3848
auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, keepdim);
3949

4050
TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);
4151

4252
mean_layer->setName(util::node_info(n).c_str());
43-
ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
53+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mean_layer->getOutput(0));
54+
55+
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
4456
return true;
4557
}
4658
});

Diff for: tests/core/converters/test_reduce.cpp

+26-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,32 @@ TEST(Converters, ATenMeanRowConvertsCorrectly) {
5959
auto g = std::make_shared<torch::jit::Graph>();
6060
torch::jit::script::parseIR(graph, &*g);
6161

62-
auto in = at::randint(-5, 5, {4, 4}, at::kCUDA);
62+
auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
63+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
64+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
65+
66+
in = at::clone(in);
67+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
68+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
69+
70+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0]));
71+
}
72+
73+
TEST(Converters, ATenMeanMultiDimsConvertsCorrectly) {
74+
const auto graph = R"IR(
75+
graph(%0 : Tensor):
76+
%1 : int = prim::Constant[value=0]()
77+
%2 : int = prim::Constant[value=1]()
78+
%3 : int[] = prim::ListConstruct(%1, %2)
79+
%4 : bool = prim::Constant[value=0]()
80+
%5 : None = prim::Constant()
81+
%6 : Tensor = aten::mean(%0, %3, %4, %5)
82+
return (%6))IR";
83+
84+
auto g = std::make_shared<torch::jit::Graph>();
85+
torch::jit::script::parseIR(graph, &*g);
86+
87+
auto in = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
6388
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
6489
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
6590

0 commit comments

Comments (0)