4
4
#include " tests/util/util.h"
5
5
#include " core/compiler.h"
6
6
7
- void pointwise_test_helper (std::string graph_ir) {
7
+ void pointwise_test_helper (std::string graph_ir, bool singleInput ) {
8
8
auto g = std::make_shared<torch::jit::Graph>();
9
9
torch::jit::parseIR (graph_ir, &*g);
10
-
11
- auto in0 = at::randint (1 , 5 , {5 }, {at::kCUDA });
12
- auto in1 = at::randint (1 , 5 , {5 }, {at::kCUDA });
10
+
11
+ // singleInput case is enabled when elementwise operation is performed
12
+ // with an input and a constant embedded in graph
13
+ std::vector<at::Tensor> torch_inputs;
14
+ torch_inputs.push_back (at::randint (1 , 5 , {5 }, {at::kCUDA }));
15
+ if (!singleInput) {
16
+ torch_inputs.push_back (at::randint (1 , 5 , {5 }, {at::kCUDA }));
17
+ }
13
18
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
14
- auto jit_results = trtorch::tests::util::RunGraph (g, params, {in0, in1});
19
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, torch_inputs);
20
+
21
+ std::vector<at::Tensor> trt_inputs;
22
+ for (auto in : torch_inputs) {
23
+ trt_inputs.push_back (at::clone (in));
24
+ }
15
25
16
- in0 = at::clone (in0);
17
- in1 = at::clone (in1);
18
26
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
19
- auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in0, in1} );
27
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, trt_inputs );
20
28
21
29
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
22
30
}
23
31
24
32
25
-
26
33
TEST (Converters, ATenAddConvertsCorrectly) {
27
34
const auto graph = R"IR(
28
35
graph(%0 : Tensor, %1 : Tensor):
29
36
%2 : int = prim::Constant[value=1]()
30
37
%3 : Tensor = aten::add(%0, %1, %2)
31
38
return (%3))IR" ;
32
- pointwise_test_helper (graph);
39
+ pointwise_test_helper (graph, false );
33
40
}
34
41
35
42
@@ -39,7 +46,7 @@ TEST(Converters, ATenAddConvertsCorrectly) {
39
46
// %2 : int = prim::Constant[value=2]()
40
47
// %3 : Tensor = aten::add(%0, %1, %2)
41
48
// return (%3))IR";
42
- // pointwise_test_helper(graph);
49
+ // pointwise_test_helper(graph, false );
43
50
// }
44
51
45
52
TEST (Converters, ATenSubConvertsCorrectly) {
@@ -48,7 +55,7 @@ TEST(Converters, ATenSubConvertsCorrectly) {
48
55
%2 : int = prim::Constant[value=1]()
49
56
%3 : Tensor = aten::sub(%0, %1, %2)
50
57
return (%3))IR" ;
51
- pointwise_test_helper (graph);
58
+ pointwise_test_helper (graph, false );
52
59
}
53
60
54
61
// TEST(Converters, ATenSubWithScaleConvertsCorrectly) {
@@ -57,21 +64,38 @@ TEST(Converters, ATenSubConvertsCorrectly) {
57
64
// %2 : float = prim::Constant[value=0.5]()
58
65
// %3 : Tensor = aten::add(%0, %1, %2)
59
66
// return (%3))IR";
60
- // pointwise_test_helper(graph);
67
+ // pointwise_test_helper(graph, false );
61
68
// }
62
69
63
70
TEST (Converters, ATenMulConvertsCorrectly) {
64
71
const auto graph = R"IR(
65
72
graph(%0 : Tensor, %1 : Tensor):
66
73
%2 : Tensor = aten::mul(%0, %1)
67
74
return (%2))IR" ;
68
- pointwise_test_helper (graph);
75
+ pointwise_test_helper (graph, false );
69
76
}
70
77
71
78
TEST (Converters, ATenDivConvertsCorrectly) {
72
79
const auto graph = R"IR(
73
80
graph(%0 : Tensor, %1 : Tensor):
74
81
%2 : Tensor = aten::div(%0, %1)
75
82
return (%2))IR" ;
76
- pointwise_test_helper (graph);
83
+ pointwise_test_helper (graph, false );
84
+ }
85
+
86
+ TEST (Converters, ATenPowTensorConvertsCorrectly) {
87
+ const auto graph = R"IR(
88
+ graph(%x.1 : Tensor, %x2.1 : Tensor):
89
+ %3 : Tensor = aten::pow(%x.1, %x2.1)
90
+ return (%3))IR" ;
91
+ pointwise_test_helper (graph, false );
92
+ }
93
+
94
+ TEST (Converters, ATenPowScalarConvertsCorrectly) {
95
+ const auto graph = R"IR(
96
+ graph(%x.1 : Tensor):
97
+ %2 : int = prim::Constant[value=2]()
98
+ %3 : Tensor = aten::pow(%x.1, %2)
99
+ return (%3))IR" ;
100
+ pointwise_test_helper (graph, true );
77
101
}
0 commit comments