Skip to content

Commit 4494699

Browse files
authored
Fixing aten::slice invalid schema and implementing aten::list evaluator (#1695)
1 parent 6a26856 commit 4494699

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

core/conversion/evaluators/aten.cpp

+20-6
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,20 @@ auto aten_registrations TORCHTRT_UNUSED =
223223
{c10::Symbol::fromQualString("aten::slice"),
224224
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
225225
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
226-
227226
int64_t start = 0;
227+
int64_t end = 9223372036854775807;
228228
auto startIVal = args.at(n->input(1)).IValue();
229+
auto endIVal = args.at(n->input(2)).IValue();
230+
229231
if (!startIVal->isNone()) {
230232
start = args.at(n->input(1)).unwrapToInt();
231233
}
232-
int64_t end = args.at(n->input(2)).unwrapToInt();
234+
if (!endIVal->isNone()) {
235+
end = args.at(n->input(2)).unwrapToInt();
236+
}
237+
if (start > end) {
238+
LOG_DEBUG("The end should be greater than start");
239+
}
233240
int64_t step = args.at(n->input(3)).unwrapToInt();
234241

235242
const int64_t list_size = list.size();
@@ -253,8 +260,9 @@ auto aten_registrations TORCHTRT_UNUSED =
253260

254261
return sliced_list;
255262
},
256-
EvalOptions().validSchemas(
257-
{"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
263+
EvalOptions().validSchemas({"aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[])"})})
264+
// EvalOptions().validSchemas(
265+
// {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
258266
.evaluator(
259267
{c10::Symbol::fromQualString("aten::len"),
260268
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
@@ -896,8 +904,14 @@ auto aten_registrations TORCHTRT_UNUSED =
896904
auto step = args.at(n->input(2)).unwrapToInt();
897905
return start + idx * step;
898906
},
899-
EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})});
900-
907+
EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})})
908+
.evaluator(
909+
{c10::Symbol::fromQualString("aten::list"),
910+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
911+
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
912+
return list.copy();
913+
},
914+
EvalOptions().validSchemas({"aten::list.t(t[] l) -> (t[])"})});
901915
} // namespace
902916
} // namespace evaluators
903917
} // namespace conversion

tests/core/conversion/evaluators/evaluator_test.bzl

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ def evaluator_test(name, visibility = None):
2222
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
2323
"//conditions:default": ["@libtorch//:libtorch"],
2424
}),
25-
timeout = "short",
25+
timeout = "long",
2626
)

tests/core/conversion/evaluators/test_aten_evaluators.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -931,3 +931,38 @@ TEST(Evaluators, IsNotTrueEvaluatesCorrectly) {
931931

932932
ASSERT_TRUE(jit_results[0] == trt_results[0]);
933933
}
934+
935+
TEST(Evaluators, IsAtenSliceEvaluateCorrectly) {
936+
const auto graph = R"IR(
937+
graph():
938+
%1 : int[] = prim::Constant[value= [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]()
939+
%2 : int = prim::Constant[value = 0]()
940+
%3 : int = prim::Constant[value = 7]()
941+
%4 : int = prim::Constant[value = 2]()
942+
%5 : int[] = aten::slice(%1, %2, %3, %4)
943+
return (%5))IR";
944+
945+
auto g = std::make_shared<torch::jit::Graph>();
946+
torch::jit::parseIR(graph, g.get());
947+
948+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
949+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
950+
951+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
952+
}
953+
954+
TEST(Evaluators, IsAtenListEvaluateCorrectly) {
955+
const auto graph = R"IR(
956+
graph():
957+
%1 : int[] = prim::Constant[value= [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]()
958+
%2 : int[] = aten::list(%1)
959+
return (%2))IR";
960+
961+
auto g = std::make_shared<torch::jit::Graph>();
962+
torch::jit::parseIR(graph, g.get());
963+
964+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
965+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
966+
967+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
968+
}

0 commit comments

Comments
 (0)