Skip to content

Commit da6684e

Browse files
Fix output dimensions of unbind (#1373)
1 parent dd88afc commit da6684e

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

core/conversion/converters/impl/select.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
6161
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
6262
auto gather_out = gather_layer->getOutput(0);
6363

64+
if (unbind) { // unbind removes the split dimension
65+
auto squeeze_layer = ctx->net->addShuffle(*gather_out);
66+
squeeze_layer->setReshapeDimensions(util::squeezeDims(gather_out->getDimensions(), axis));
67+
TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n);
68+
gather_out = squeeze_layer->getOutput(0);
69+
}
70+
6471
auto tensor_holder = TensorContainer();
6572
tensor_holder.hold_tensor(gather_out);
6673
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));

tests/core/conversion/converters/test_select.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,7 @@ TEST(Converters, ATenUnbindConvertsCorrectly) {
10061006
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
10071007

10081008
for (size_t i = 0; i < jit_results.size(); i++) {
1009-
auto trt = trt_results[i].reshape(jit_results[i].sizes());
1009+
auto trt = trt_results[i];
10101010
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
10111011
}
10121012
}
@@ -1033,7 +1033,7 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
10331033
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
10341034

10351035
for (size_t i = 0; i < jit_results.size(); i++) {
1036-
auto trt = trt_results[i].reshape(jit_results[i].sizes());
1036+
auto trt = trt_results[i];
10371037
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
10381038
}
10391039
}

0 commit comments

Comments
 (0)