diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 58f0aeae61..e57fdd93d2 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -61,6 +61,13 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s auto gather_layer = ctx->net->addGather(*in, *indicesTensor, axis); auto gather_out = gather_layer->getOutput(0); + if (unbind) { // unbind removes the split dimension + auto squeeze_layer = ctx->net->addShuffle(*gather_out); + squeeze_layer->setReshapeDimensions(util::squeezeDims(gather_out->getDimensions(), axis)); + TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n); + gather_out = squeeze_layer->getOutput(0); + } + auto tensor_holder = TensorContainer(); tensor_holder.hold_tensor(gather_out); auto ival = c10::IValue(std::move(c10::make_intrusive(tensor_holder))); diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 1285c24dd6..2bc51735a2 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -974,7 +974,7 @@ TEST(Converters, ATenUnbindConvertsCorrectly) { auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); + auto trt = trt_results[i]; ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); } } @@ -1001,7 +1001,7 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) { auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); for (size_t i = 0; i < jit_results.size(); i++) { - auto trt = trt_results[i].reshape(jit_results[i].sizes()); + auto trt = trt_results[i]; ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); } }