@@ -365,6 +365,38 @@ TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) {
365
365
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
366
366
}
367
367
368
+ TEST (Converters, ATenSliceListConvertsCorrectly) {
369
+ const auto graph = R"IR(
370
+ graph(%x : Tensor):
371
+ %1 : NoneType = prim::Constant()
372
+ %2 : int = prim::Constant[value=2]()
373
+ %3 : int = prim::Constant[value=1]()
374
+ %4 : int = prim::Constant[value=3]()
375
+ %list : Tensor[] = aten::unbind(%x, %4)
376
+ %slice : Tensor[] = aten::slice(%list, %1, %2, %3)
377
+ %out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice)
378
+ return (%out.1, %out.2))IR" ;
379
+
380
+ auto g = std::make_shared<torch::jit::Graph>();
381
+
382
+ torch::jit::parseIR (graph, g.get ());
383
+
384
+ auto in_x = at::randint (1 , 10 , {6 , 5 , 3 , 3 }, {at::kCUDA });
385
+
386
+ auto jit_in_x = at::clone (in_x);
387
+
388
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
389
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in_x});
390
+
391
+ auto trt_in_x = at::clone (in_x);
392
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in_x});
393
+
394
+ for (size_t i = 0 ; i < jit_results.size (); i++) {
395
+ auto trt = trt_results[i].reshape (jit_results[i].sizes ());
396
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
397
+ }
398
+ }
399
+
368
400
TEST (Converters, ATenSliceDynamicBatchConvertsCorrectly) {
369
401
const auto graph = R"IR(
370
402
graph(%x.1 : Tensor):
@@ -796,3 +828,30 @@ TEST(Converters, ATenUnbindConvertsCorrectly) {
796
828
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
797
829
}
798
830
}
831
+
832
+ TEST (Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
833
+ const auto graph = R"IR(
834
+ graph(%x.1 : Tensor):
835
+ %2 : int = prim::Constant[value=-1]()
836
+ %3 : Tensor[] = aten::unbind(%x.1, %2)
837
+ %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3)
838
+ return (%o1.1, %o2.1))IR" ;
839
+
840
+ auto g = std::make_shared<torch::jit::Graph>();
841
+
842
+ torch::jit::parseIR (graph, g.get ());
843
+
844
+ auto in = at::randint (1 , 10 , {5 , 2 }, {at::kCUDA });
845
+
846
+ auto jit_in = at::clone (in);
847
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
848
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in});
849
+
850
+ auto trt_in = at::clone (in);
851
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in});
852
+
853
+ for (size_t i = 0 ; i < jit_results.size (); i++) {
854
+ auto trt = trt_results[i].reshape (jit_results[i].sizes ());
855
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
856
+ }
857
+ }
0 commit comments