@@ -30,17 +30,87 @@ TEST(Converters, ATenFlattenConvertsCorrectly) {
30
30
31
31
// TODO: IR Parser doesnt work well with neg numbers
32
32
TEST (Converters, ATenFlattenOtherDimsConvertsCorrectly) {
33
- const auto graph = R"IR(
34
- graph(%0 : Tensor):
35
- %1 : int = prim::Constant[value=1]()
36
- %2 : int = prim::Constant[value=2]()
37
- %3 : Tensor = aten::flatten(%0, %1, %2)
38
- return (%3))IR" ;
33
+ const auto graph = R"IR(
34
+ graph(%0 : Tensor):
35
+ %1 : int = prim::Constant[value=1]()
36
+ %2 : int = prim::Constant[value=2]()
37
+ %3 : Tensor = aten::flatten(%0, %1, %2)
38
+ return (%3))IR" ;
39
+
40
+ auto g = std::make_shared<torch::jit::Graph>();
41
+ torch::jit::parseIR (graph, &*g);
42
+
43
+ auto in = at::randint (0 , 5 , {2 , 3 , 3 }, {at::kCUDA });
44
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
45
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
46
+
47
+ in = at::clone (in);
48
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
49
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
50
+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
51
+
52
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
53
+ }
39
54
40
- auto g = std::make_shared<torch::jit::Graph>();
55
+ TEST (Converters, ATenReshapeConvertsCorrectly) {
56
+ const auto graph = R"IR(
57
+ graph(%0 : Tensor):
58
+ %1 : int = prim::Constant[value=3]()
59
+ %2 : int = prim::Constant[value=2]()
60
+ %3 : int[] = prim::ListConstruct(%1, %2)
61
+ %4 : Tensor = aten::reshape(%0, %3)
62
+ return (%4))IR" ;
63
+
64
+ auto g = std::make_shared<torch::jit::Graph>();
65
+ torch::jit::parseIR (graph, &*g);
66
+
67
+ auto in = at::randint (0 , 5 , {2 , 3 }, {at::kCUDA });
68
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
69
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
70
+
71
+ in = at::clone (in);
72
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
73
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
74
+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
75
+
76
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
77
+ }
78
+
79
+ TEST (Converters, ATenViewConvertsCorrectly) {
80
+ const auto graph = R"IR(
81
+ graph(%0 : Tensor):
82
+ %1 : int = prim::Constant[value=3]()
83
+ %2 : int = prim::Constant[value=2]()
84
+ %3 : int[] = prim::ListConstruct(%1, %2)
85
+ %4 : Tensor = aten::view(%0, %3)
86
+ return (%4))IR" ;
87
+
88
+ auto g = std::make_shared<torch::jit::Graph>();
89
+ torch::jit::parseIR (graph, &*g);
90
+
91
+ auto in = at::randint (0 , 5 , {2 , 3 }, {at::kCUDA });
92
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
93
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
94
+
95
+ in = at::clone (in);
96
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
97
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
98
+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
99
+
100
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
101
+ }
102
+
103
+ TEST (Converters, ATenPermuteConvertsCorrectly) {
104
+ const auto graph = R"IR(
105
+ graph(%x.1 : Tensor):
106
+ %2 : int[] = prim::Constant[value=[3, 0, 1, 2]]()
107
+ %3 : Tensor = aten::permute(%x.1, %2)
108
+ return (%3))IR" ;
109
+
110
+ auto g = std::make_shared<torch::jit::Graph>();
41
111
torch::jit::parseIR (graph, &*g);
42
112
43
- auto in = at::randint (0 , 5 , {2 , 3 , 3 }, {at::kCUDA });
113
+ auto in = at::randint (0 , 5 , {2 , 3 , 2 , 3 }, {at::kCUDA });
44
114
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
45
115
auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
46
116
@@ -52,19 +122,17 @@ TEST(Converters, ATenFlattenOtherDimsConvertsCorrectly) {
52
122
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
53
123
}
54
124
55
- TEST (Converters, ATenReshapeConvertsCorrectly) {
56
- const auto graph = R"IR(
57
- graph(%0 : Tensor):
58
- %1 : int = prim::Constant[value=3]()
59
- %2 : int = prim::Constant[value=2]()
60
- %3 : int[] = prim::ListConstruct(%1, %2)
61
- %4 : Tensor = aten::reshape(%0, %3)
62
- return (%4))IR" ;
125
+ TEST (Converters, ATenPermute3DConvertsCorrectly) {
126
+ const auto graph = R"IR(
127
+ graph(%x.1 : Tensor):
128
+ %2 : int[] = prim::Constant[value=[0, 2, 1]]()
129
+ %3 : Tensor = aten::permute(%x.1, %2)
130
+ return (%3))IR" ;
63
131
64
- auto g = std::make_shared<torch::jit::Graph>();
132
+ auto g = std::make_shared<torch::jit::Graph>();
65
133
torch::jit::parseIR (graph, &*g);
66
134
67
- auto in = at::randint (0 , 5 , {2 , 3 }, {at::kCUDA });
135
+ auto in = at::randint (0 , 5 , {2 , 2 , 3 }, {at::kCUDA });
68
136
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
69
137
auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
70
138
@@ -76,19 +144,17 @@ TEST(Converters, ATenReshapeConvertsCorrectly) {
76
144
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
77
145
}
78
146
79
- TEST (Converters, ATenViewConvertsCorrectly) {
80
- const auto graph = R"IR(
81
- graph(%0 : Tensor):
82
- %1 : int = prim::Constant[value=3]()
83
- %2 : int = prim::Constant[value=2]()
84
- %3 : int[] = prim::ListConstruct(%1, %2)
85
- %4 : Tensor = aten::view(%0, %3)
86
- return (%4))IR" ;
147
+ TEST (Converters, ATenPermute5DConvertsCorrectly) {
148
+ const auto graph = R"IR(
149
+ graph(%x.1 : Tensor):
150
+ %2 : int[] = prim::Constant[value=[3, 4, 0, 2, 1]]()
151
+ %3 : Tensor = aten::permute(%x.1, %2)
152
+ return (%3))IR" ;
87
153
88
- auto g = std::make_shared<torch::jit::Graph>();
154
+ auto g = std::make_shared<torch::jit::Graph>();
89
155
torch::jit::parseIR (graph, &*g);
90
156
91
- auto in = at::randint (0 , 5 , {2 , 3 }, {at::kCUDA });
157
+ auto in = at::randint (0 , 5 , {2 , 2 , 1 , 2 , 3 }, {at::kCUDA });
92
158
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
93
159
auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
94
160
0 commit comments