Commit 20e5d41

fix(//core/conversion/evaluators): Change how schemas are handled in aten::range evaluator

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent: 3b1ce7c

1 file changed: +8 −12

Diff for: core/conversion/evaluators/aten.cpp

@@ -620,22 +620,19 @@ auto aten_registrations TORCHTRT_UNUSED =
             {"aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)"})})
         .evaluator({c10::Symbol::fromQualString("aten::arange"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      int input_size = n->inputs().size();
-                      int scalar_count = 0;
-                      for (int i = 0; i < input_size; i++) {
-                        if (args.at(n->input(i)).IValue()->isScalar()) {
-                          scalar_count += 1;
-                        }
-                      }
-                      if (scalar_count == 1) {
+                      auto schema = n->maybeSchema();
+                      TORCHTRT_CHECK(schema, "Unable to get schema for node: " << *n);
+                      auto name = schema->operator_name();
+
+                      if (c10::toString(name) == "aten::arange") {
                         if (args.at(n->input(0)).IValue()->isInt()) {
                           int end_scalar = args.at(n->input(0)).unwrapToInt();
                           return torch::arange(end_scalar);
                         } else if (args.at(n->input(0)).IValue()->isDouble()) {
                           float end_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
                           return torch::arange(end_scalar);
                         }
-                      } else if (scalar_count == 2) {
+                      } else if (c10::toString(name) == "aten::arange.start") {
                         if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble()) {
                           float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
                           float end_scalar = args.at(n->input(1)).unwrapToScalar().to<float>();
@@ -645,7 +642,7 @@ auto aten_registrations TORCHTRT_UNUSED =
                           int end_scalar = args.at(n->input(1)).unwrapToInt();
                           return torch::arange(start_scalar, end_scalar);
                         }
-                      } else if (scalar_count == 3) {
+                      } else if (c10::toString(name) == "aten::arange.start_step") {
                         if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble() ||
                             args.at(n->input(2)).IValue()->isDouble()) {
                           float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
@@ -659,8 +656,7 @@ auto aten_registrations TORCHTRT_UNUSED =
                           return torch::arange(start_scalar, end_scalar, step_scalar);
                         }
                       } else {
-                        TORCHTRT_THROW_ERROR(
-                            "Invalid input argument size for aten::arange, input argument size: " << input_size);
+                        TORCHTRT_THROW_ERROR("Unsupported aten::arange variant: " << name);
                       }
                       return {};
                     },
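
The rewritten evaluator dispatches on the overload name carried by the node's schema instead of counting scalar inputs: torch::jit::Node::maybeSchema() may return nullptr (hence the TORCHTRT_CHECK), and FunctionSchema::operator_name() returns a c10::OperatorName that c10::toString renders as "aten::arange", "aten::arange.start", or "aten::arange.start_step". Below is a minimal standalone sketch of that naming behavior, assuming only a LibTorch build; the OperatorName values are constructed by hand for illustration and are not part of the commit.

#include <iostream>
#include <ATen/core/operator_name.h>

int main() {
  // FunctionSchema::operator_name() returns a c10::OperatorName; its
  // overload_name field is what distinguishes the arange variants.
  c10::OperatorName plain{"aten::arange", ""};
  c10::OperatorName start{"aten::arange", "start"};
  c10::OperatorName start_step{"aten::arange", "start_step"};

  // c10::toString joins the base name and the overload name with a '.',
  // so the comparisons in the evaluator match strings such as
  // "aten::arange.start_step".
  std::cout << c10::toString(plain) << "\n";       // aten::arange
  std::cout << c10::toString(start) << "\n";       // aten::arange.start
  std::cout << c10::toString(start_step) << "\n";  // aten::arange.start_step
  return 0;
}

Dispatching on the schema's overload name rather than on a scalar count presumably also keeps the choice of variant independent of how many of the trailing optional arguments (dtype, layout, device, pin_memory) happen to be scalars.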
