@@ -42,7 +42,7 @@ TEST(CppAPITests, TestCollectionStandardTensorInput) {
42
42
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
43
43
auto trt_out = trt_mod.forward (inputs_);
44
44
45
- ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor (), 0.99 ));
45
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor ()));
46
46
}
47
47
48
48
TEST (CppAPITests, TestCollectionTupleInput) {
@@ -85,7 +85,7 @@ TEST(CppAPITests, TestCollectionTupleInput) {
85
85
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
86
86
auto trt_out = trt_mod.forward (complex_inputs);
87
87
88
- ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor (), 0.99 ));
88
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor ()));
89
89
}
90
90
91
91
TEST (CppAPITests, TestCollectionListInput) {
@@ -144,7 +144,7 @@ TEST(CppAPITests, TestCollectionListInput) {
144
144
LOG_DEBUG (" Finish compile" );
145
145
auto trt_out = trt_mod.forward (complex_inputs);
146
146
147
- ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor (), 0.99 ));
147
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (out.toTensor (), trt_out.toTensor ()));
148
148
}
149
149
150
150
TEST (CppAPITests, TestCollectionTupleInputOutput) {
@@ -178,23 +178,20 @@ TEST(CppAPITests, TestCollectionTupleInputOutput) {
178
178
torch::jit::IValue complex_input_shape (input_shape_tuple);
179
179
std::tuple<torch::jit::IValue> input_tuple2 (complex_input_shape);
180
180
torch::jit::IValue complex_input_shape2 (input_tuple2);
181
- // torch::jit::IValue complex_input_shape(list);
182
181
183
182
auto compile_settings = torch_tensorrt::ts::CompileSpec (complex_input_shape2);
184
183
compile_settings.min_block_size = 1 ;
185
184
186
- // compile_settings.torch_executed_ops.push_back("prim::TupleConstruct");
187
-
188
185
// // FP16 execution
189
186
compile_settings.enabled_precisions = {torch::kHalf };
190
187
// // Compile module
191
188
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
192
189
auto trt_out = trt_mod.forward (complex_inputs);
193
190
194
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
195
- out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor (), 1e-5 ));
196
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
197
- out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor (), 1e-5 ));
191
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (
192
+ out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor ()));
193
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (
194
+ out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor ()));
198
195
}
199
196
200
197
TEST (CppAPITests, TestCollectionListInputOutput) {
@@ -252,10 +249,10 @@ TEST(CppAPITests, TestCollectionListInputOutput) {
252
249
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
253
250
auto trt_out = trt_mod.forward (complex_inputs);
254
251
255
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
256
- out.toList ().vec ()[0 ].toTensor (), trt_out.toList ().vec ()[0 ].toTensor (), 1e-5 ));
257
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
258
- out.toList ().vec ()[1 ].toTensor (), trt_out.toList ().vec ()[1 ].toTensor (), 1e-5 ));
252
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (
253
+ out.toList ().vec ()[0 ].toTensor (), trt_out.toList ().vec ()[0 ].toTensor ()));
254
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (
255
+ out.toList ().vec ()[1 ].toTensor (), trt_out.toList ().vec ()[1 ].toTensor ()));
259
256
}
260
257
261
258
TEST (CppAPITests, TestCollectionComplexModel) {
@@ -313,8 +310,8 @@ TEST(CppAPITests, TestCollectionComplexModel) {
313
310
auto trt_mod = torch_tensorrt::torchscript::compile (mod, compile_settings);
314
311
auto trt_out = trt_mod.forward (complex_inputs);
315
312
316
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
317
- out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor (), 1e-5 ));
318
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
319
- out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor (), 1e-5 ));
313
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (
314
+ out.toTuple ()->elements ()[0 ].toTensor (), trt_out.toTuple ()->elements ()[0 ].toTensor ()));
315
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (
316
+ out.toTuple ()->elements ()[1 ].toTensor (), trt_out.toTuple ()->elements ()[1 ].toTensor ()));
320
317
}
0 commit comments