@@ -77,7 +77,7 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) {
77
77
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
78
78
}
79
79
80
- TEST (Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex ) {
80
+ TEST (Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveOneIndex ) {
81
81
const auto graph = R"IR(
82
82
graph(%0 : Tensor):
83
83
%1 : None = prim::Constant()
@@ -100,5 +100,31 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
100
100
101
101
auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
102
102
103
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
104
+ }
105
+
106
+ TEST (Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
107
+ const auto graph = R"IR(
108
+ graph(%0 : Tensor):
109
+ %1 : None = prim::Constant()
110
+ %2 : int = prim::Constant[value=-2]()
111
+ %3 : Tensor = aten::softmax(%0, %2, %1)
112
+ return (%3))IR" ;
113
+
114
+ auto g = std::make_shared<torch::jit::Graph>();
115
+ torch::jit::parseIR (graph, &*g);
116
+
117
+ auto in = at::randint (0 , 5 , {1 , 2 , 2 , 2 , 2 }, {at::kCUDA });
118
+
119
+ auto jit_in = at::clone (in);
120
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
121
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
122
+
123
+ auto trt_in = at::clone (in);
124
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
125
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
126
+
127
+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
128
+
103
129
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
104
130
}
0 commit comments