Skip to content

Commit 11c4608

Browse files
committed
feat(aten::__range_length): Adding range length evaluator
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 4fd886d commit 11c4608

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

Diff for: core/conversion/evaluators/aten.cpp

+19-1
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,25 @@ auto aten_registrations TORCHTRT_UNUSED =
718718
torch::jit::pop(stack, output);
719719
return output;
720720
},
721-
EvalOptions().validSchemas({"aten::format(str self, ...) -> (str)"})});
721+
EvalOptions().validSchemas({"aten::format(str self, ...) -> (str)"})})
722+
.evaluator({c10::Symbol::fromQualString("aten::__range_length"),
723+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
724+
auto lo = args.at(n->input(0)).unwrapToInt();
725+
auto hi = args.at(n->input(1)).unwrapToInt();
726+
auto step = args.at(n->input(2)).unwrapToInt();
727+
728+
if (step == 0) {
729+
TORCHTRT_THROW_ERROR("aten::__range_length() arg 3 must not be zero");
730+
}
731+
if (step > 0 && lo < hi) {
732+
return 1 + (hi - 1 - lo) / step;
733+
} else if (step < 0 && lo > hi) {
734+
return 1 + (lo - 1 - hi) / (0 - step);
735+
} else {
736+
return 0;
737+
}
738+
},
739+
EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})});
722740
} // namespace
723741
} // namespace evaluators
724742
} // namespace conversion

Diff for: tests/core/conversion/evaluators/test_aten_evaluators.cpp

+36
Original file line numberDiff line numberDiff line change
@@ -665,4 +665,40 @@ TEST(Evaluators, AtenFormatRaiseExceptionEvaluatesCorrectly) {
665665
} else {
666666
ASSERT_TRUE(false);
667667
}
668+
}
669+
670+
TEST(Evaluators, RangeLengthEvaluatesCorrectly) {
671+
const auto graph = R"IR(
672+
graph():
673+
%1 : int = prim::Constant[value=1]()
674+
%2 : int = prim::Constant[value=10]()
675+
%3 : int = prim::Constant[value=2]()
676+
%4 : int = aten::__range_length(%1, %2, %3)
677+
return (%3))IR";
678+
679+
auto g = std::make_shared<torch::jit::Graph>();
680+
torch::jit::parseIR(graph, g.get());
681+
682+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
683+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
684+
685+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
686+
}
687+
688+
TEST(Evaluators, RangeLengthNegEvaluatesCorrectly) {
689+
const auto graph = R"IR(
690+
graph():
691+
%1 : int = prim::Constant[value=10]()
692+
%2 : int = prim::Constant[value=1]()
693+
%3 : int = prim::Constant[value=-2]()
694+
%4 : int = aten::__range_length(%1, %2, %3)
695+
return (%3))IR";
696+
697+
auto g = std::make_shared<torch::jit::Graph>();
698+
torch::jit::parseIR(graph, g.get());
699+
700+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
701+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
702+
703+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
668704
}

0 commit comments

Comments
 (0)