Skip to content

Commit f6ba547

Browse files
authored
fix: Correct aten::split behavior when negative indexing is applied (#1403)
- Identify issue arising in the deepset/roberta-base-squad2 HuggingFace network (https://huggingface.co/deepset/roberta-base-squad2) during conversion - Bug involves the use of a negative index (-1) when applying the aten::split operator - Improved axis handling in add_split to account for negative indexing - Improved error handling in add_split to ensure valid axes for C++ indexing
1 parent 5a7ac8f commit f6ba547

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

core/conversion/converters/impl/select.cpp

+13-6
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,23 @@ namespace {
1717

1818
bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind) {
1919
auto in = args[0].ITensor();
20-
auto numOutputs = 1, numRemainder = 0, axis = 0;
20+
auto numOutputs = 1, numRemainder = 0;
2121
std::vector<int64_t> sizes;
2222

23+
// Precompute axis along which to apply split, ensuring negative dimensions are re-indexed
24+
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
25+
auto input_axis = unbind ? args[1].unwrapToInt() : args[2].unwrapToInt();
26+
auto axis = input_axis < 0 ? input_axis + maxDim : input_axis;
27+
28+
// Ensure input axis is valid for input tensor
29+
TORCHTRT_CHECK(
30+
(axis >= 0) && (axis < maxDim),
31+
"Expected input axis to fall in range [-" << maxDim << ", " << (maxDim - 1) << "], got " << input_axis);
32+
2333
if (unbind) {
24-
axis = args[1].unwrapToInt();
25-
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
26-
axis = axis < 0 ? axis + maxDim : axis;
2734
numOutputs = in->getDimensions().d[axis];
2835
sizes.insert(sizes.end(), numOutputs, 1);
2936
} else {
30-
axis = args[2].unwrapToInt();
3137
auto inDimSize = in->getDimensions().d[axis];
3238
if (split_list) {
3339
sizes = args[1].unwrapToIntList().vec();
@@ -274,7 +280,8 @@ auto select_registrations TORCHTRT_UNUSED =
274280
.pattern(
275281
{"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
276282
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
277-
// refer to https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py#L4627
283+
// refer to
284+
// https://github.com/pytorch/pytorch/blob/974ad8fa6cc63b89234beb5ebff54c2d42711932/torch/onnx/symbolic_opset9.py#L4627
278285
auto in = args[0].ITensorOrFreeze(ctx);
279286
auto ts = args[1].IValue()->toListRef();
280287

tests/core/conversion/converters/test_select.cpp

+29-1
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,34 @@ TEST(Converters, ATenSplitAndAddConvertsCorrectly) {
739739
}
740740
}
741741

742+
TEST(Converters, ATenSplitNegativeDimsConvertsCorrectly) {
743+
const auto graph = R"IR(
744+
graph(%x.1 : Tensor):
745+
%2 : int = prim::Constant[value=1]()
746+
%n1 : int = prim::Constant[value=-1]()
747+
%3 : Tensor[] = aten::split(%x.1, %2, %n1)
748+
%4 : Tensor, %5 : Tensor, %6 : Tensor, %7 : Tensor = prim::ListUnpack(%3)
749+
return (%4, %5, %6, %7))IR";
750+
751+
auto g = std::make_shared<torch::jit::Graph>();
752+
753+
torch::jit::parseIR(graph, g.get());
754+
755+
auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});
756+
757+
auto jit_in = at::clone(in);
758+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
759+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
760+
761+
auto trt_in = at::clone(in);
762+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
763+
764+
for (size_t i = 0; i < jit_results.size(); i++) {
765+
auto trt = trt_results[i].reshape(jit_results[i].sizes());
766+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
767+
}
768+
}
769+
742770
TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
743771
const auto graph = R"IR(
744772
graph(%x.1 : Tensor):
@@ -1109,4 +1137,4 @@ TEST(Converters, ScatterSrcConvertsCorrectly) {
11091137
auto trt = trt_results[i].reshape(jit_results[i].sizes());
11101138
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
11111139
}
1112-
}
1140+
}

0 commit comments

Comments
 (0)