Skip to content

Commit abc29a2

Browse files
committed
feat(aten::softmax): Adding support for any neg index
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 658ee4f commit abc29a2

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

Diff for: core/conversion/converters/impl/softmax.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ static auto softmax_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
2727

2828
int64_t dim = args[1].IValue()->toInt();
2929
LOG_DEBUG("Softmax original dim " << dim);
30-
if (dim == -1) {
31-
dim = shape.size() - 1;
30+
if (dim < 0) {
31+
dim = shape.size() + dim;
3232
}
3333
LOG_DEBUG("Softmax converted dim " << dim);
3434
auto softmax = ctx->net->addSoftMax(*in);

Diff for: tests/core/conversion/converters/test_softmax.cpp

+27-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) {
7777
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
7878
}
7979

80-
TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
80+
TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveOneIndex) {
8181
const auto graph = R"IR(
8282
graph(%0 : Tensor):
8383
%1 : None = prim::Constant()
@@ -100,5 +100,31 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
100100

101101
auto trt = trt_results[0].reshape_as(jit_results[0]);
102102

103+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
104+
}
105+
106+
TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
107+
const auto graph = R"IR(
108+
graph(%0 : Tensor):
109+
%1 : None = prim::Constant()
110+
%2 : int = prim::Constant[value=-2]()
111+
%3 : Tensor = aten::softmax(%0, %2, %1)
112+
return (%3))IR";
113+
114+
auto g = std::make_shared<torch::jit::Graph>();
115+
torch::jit::parseIR(graph, &*g);
116+
117+
auto in = at::randint(0, 5, {1, 2, 2, 2, 2}, {at::kCUDA});
118+
119+
auto jit_in = at::clone(in);
120+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
121+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
122+
123+
auto trt_in = at::clone(in);
124+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
125+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
126+
127+
auto trt = trt_results[0].reshape_as(jit_results[0]);
128+
103129
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
104130
}

0 commit comments

Comments
 (0)