diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 58f0aeae61..f4cca7ba19 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -362,7 +362,7 @@ auto select_registrations TORCHTRT_UNUSED = nvinfer1::ElementWiseOperation::kPROD, d0, dim_tensor, - std::string("compute_dim0_") + std::to_string(i)) + util::node_info(n) + std::string("_compute_dim0_") + std::to_string(i)) ->getOutput(0); } @@ -378,7 +378,7 @@ auto select_registrations TORCHTRT_UNUSED = nvinfer1::ElementWiseOperation::kPROD, d1, dim_tensor, - std::string("compute_dim1_") + std::to_string(i)) + util::node_info(n) + std::string("_compute_dim1_") + std::to_string(i)) ->getOutput(0); } @@ -398,26 +398,27 @@ auto select_registrations TORCHTRT_UNUSED = nvinfer1::ITensor* multiplier = dim_tensor_list[adv_idx_indices[adv_idx_count - 1]]; nvinfer1::ITensor* cum_adv_index = tensors[adv_idx_count - 1]; for (int i = adv_idx_count - 2; i >= 0; i--) { - nvinfer1::ITensor* adv_index = add_elementwise( - ctx, - nvinfer1::ElementWiseOperation::kPROD, - tensors[i], - multiplier, - std::string("adv_index_") + std::to_string(i)) - ->getOutput(0); + nvinfer1::ITensor* adv_index = + add_elementwise( + ctx, + nvinfer1::ElementWiseOperation::kPROD, + tensors[i], + multiplier, + util::node_info(n) + std::string("_adv_index_") + std::to_string(i)) + ->getOutput(0); cum_adv_index = add_elementwise( ctx, nvinfer1::ElementWiseOperation::kSUM, cum_adv_index, adv_index, - std::string("cum_adv_index_") + std::to_string(i)) + util::node_info(n) + std::string("_cum_adv_index_") + std::to_string(i)) ->getOutput(0); multiplier = add_elementwise( ctx, nvinfer1::ElementWiseOperation::kPROD, multiplier, dim_tensor_list[adv_idx_indices[i]], - std::string("multiplier_") + std::to_string(i)) + util::node_info(n) + std::string("_multiplier_") + std::to_string(i)) ->getOutput(0); } diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 1285c24dd6..9d17573cd0 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -833,6 +833,38 @@ TEST(Converters, ATenIndexTensorFullIndicesConvertsCorrectly) { torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } +TEST(Converters, ATenIndexTensorRepeatedFullIndicesConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor, + %index0 : Tensor, + %index1 : Tensor, + %index2 : Tensor): + %18 : Tensor?[] = prim::ListConstruct(%index0, %index1, %index2) + %19 : Tensor = aten::index(%x.1, %18) + %20 : Tensor = aten::index(%x.1, %18) + return (%19, %20))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in1 = at::randint(1, 10, {5, 10, 4}, {at::kCUDA}); + auto index0 = at::tensor({0, 1, 2, 3}, {at::kCUDA}).to(torch::kLong); + auto index1 = at::tensor({1, 3, 4, 6}, {at::kCUDA}).to(torch::kLong); + auto index2 = at::tensor({3, 2, 1, 0}, {at::kCUDA}).to(torch::kLong); + auto index0_trt = index0.to(torch::kInt32); + auto index1_trt = index1.to(torch::kInt32); + auto index2_trt = index2.to(torch::kInt32); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, index0, index1, index2}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, index0_trt, index1_trt, index2_trt}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6)); +} + TEST(Converters, ATenIndexTensorIdx0Idx1NoneConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor,