Skip to content

Commit 8ca7a22

Browse files
authored
Merge pull request #1088 from mfeliz-cruise/michael.feliz/fix_slice_and_unbind
Fix errors in unbind and list slice
2 parents 07238c8 + d73738c commit 8ca7a22

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

core/conversion/converters/impl/select.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
2222

2323
if (unbind) {
2424
axis = args[1].unwrapToInt();
25+
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
26+
axis = axis < 0 ? axis + maxDim : axis;
2527
numOutputs = in->getDimensions().d[axis];
2628
sizes.insert(sizes.end(), numOutputs, 1);
2729
} else {

core/conversion/evaluators/aten.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,12 @@ auto aten_registrations TORCHTRT_UNUSED =
181181
.evaluator({c10::Symbol::fromQualString("aten::slice"),
182182
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
183183
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
184-
int64_t start = args.at(n->input(1)).unwrapToInt();
184+
185+
int64_t start = 0;
186+
auto startIVal = args.at(n->input(1)).IValue();
187+
if(!startIVal->isNone()){
188+
start = args.at(n->input(1)).unwrapToInt();
189+
}
185190
int64_t end = args.at(n->input(2)).unwrapToInt();
186191
int64_t step = args.at(n->input(3)).unwrapToInt();
187192

tests/core/conversion/converters/test_select.cpp

+59
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,38 @@ TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) {
365365
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
366366
}
367367

368+
TEST(Converters, ATenSliceListConvertsCorrectly) {
369+
const auto graph = R"IR(
370+
graph(%x : Tensor):
371+
%1 : NoneType = prim::Constant()
372+
%2 : int = prim::Constant[value=2]()
373+
%3 : int = prim::Constant[value=1]()
374+
%4 : int = prim::Constant[value=3]()
375+
%list : Tensor[] = aten::unbind(%x, %4)
376+
%slice : Tensor[] = aten::slice(%list, %1, %2, %3)
377+
%out.1 : Tensor, %out.2 : Tensor = prim::ListUnpack(%slice)
378+
return (%out.1, %out.2))IR";
379+
380+
auto g = std::make_shared<torch::jit::Graph>();
381+
382+
torch::jit::parseIR(graph, g.get());
383+
384+
auto in_x = at::randint(1, 10, {6, 5, 3, 3}, {at::kCUDA});
385+
386+
auto jit_in_x = at::clone(in_x);
387+
388+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
389+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in_x});
390+
391+
auto trt_in_x = at::clone(in_x);
392+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in_x});
393+
394+
for (size_t i = 0; i < jit_results.size(); i++) {
395+
auto trt = trt_results[i].reshape(jit_results[i].sizes());
396+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
397+
}
398+
}
399+
368400
TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) {
369401
const auto graph = R"IR(
370402
graph(%x.1 : Tensor):
@@ -796,3 +828,30 @@ TEST(Converters, ATenUnbindConvertsCorrectly) {
796828
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
797829
}
798830
}
831+
832+
TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
833+
const auto graph = R"IR(
834+
graph(%x.1 : Tensor):
835+
%2 : int = prim::Constant[value=-1]()
836+
%3 : Tensor[] = aten::unbind(%x.1, %2)
837+
%o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%3)
838+
return (%o1.1, %o2.1))IR";
839+
840+
auto g = std::make_shared<torch::jit::Graph>();
841+
842+
torch::jit::parseIR(graph, g.get());
843+
844+
auto in = at::randint(1, 10, {5, 2}, {at::kCUDA});
845+
846+
auto jit_in = at::clone(in);
847+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
848+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
849+
850+
auto trt_in = at::clone(in);
851+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
852+
853+
for (size_t i = 0; i < jit_results.size(); i++) {
854+
auto trt = trt_results[i].reshape(jit_results[i].sizes());
855+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
856+
}
857+
}

0 commit comments

Comments
 (0)