Skip to content

Commit f594e43

Browse files
committed
test(aten::select): added test case for single call to select
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 37c68d3 commit f594e43

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

Diff for: tests/core/converters/test_select.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,32 @@
44
#include "tests/util/util.h"
55
#include "core/compiler.h"
66

7+
TEST(Converters, ATenSelectIntConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor):
10+
%2 : int = prim::Constant[value=0]()
11+
%3 : Tensor = aten::select(%0, %2, %2)
12+
return (%3))IR";
13+
14+
auto g = std::make_shared<torch::jit::Graph>();
15+
16+
torch::jit::parseIR(graph, &*g);
17+
18+
auto in = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
19+
20+
auto jit_in = at::clone(in);
21+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
22+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
23+
24+
auto trt_in = at::clone(in);
25+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
26+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
27+
28+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
29+
30+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
31+
}
32+
733
TEST(Converters, ATenSelectIntTwiceConvertsCorrectly) {
834
const auto graph = R"IR(
935
graph(%0 : Tensor):

0 commit comments

Comments
 (0)