@@ -620,22 +620,19 @@ auto aten_registrations TORCHTRT_UNUSED =
             {"aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)"})})
         .evaluator({c10::Symbol::fromQualString("aten::arange"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-                      int input_size = n->inputs().size();
-                      int scalar_count = 0;
-                      for (int i = 0; i < input_size; i++) {
-                        if (args.at(n->input(i)).IValue()->isScalar()) {
-                          scalar_count += 1;
-                        }
-                      }
-                      if (scalar_count == 1) {
+                      auto schema = n->maybeSchema();
+                      TORCHTRT_CHECK(schema, "Unable to get schema for node: " << *n);
+                      auto name = schema->operator_name();
+
+                      if (c10::toString(name) == "aten::arange") {
                         if (args.at(n->input(0)).IValue()->isInt()) {
                           int end_scalar = args.at(n->input(0)).unwrapToInt();
                           return torch::arange(end_scalar);
                         } else if (args.at(n->input(0)).IValue()->isDouble()) {
                           float end_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
                           return torch::arange(end_scalar);
                         }
-                      } else if (scalar_count == 2) {
+                      } else if (c10::toString(name) == "aten::arange.start") {
                         if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble()) {
                           float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
                           float end_scalar = args.at(n->input(1)).unwrapToScalar().to<float>();
@@ -645,7 +642,7 @@ auto aten_registrations TORCHTRT_UNUSED =
                           int end_scalar = args.at(n->input(1)).unwrapToInt();
                           return torch::arange(start_scalar, end_scalar);
                         }
-                      } else if (scalar_count == 3) {
+                      } else if (c10::toString(name) == "aten::arange.start_step") {
                         if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble() ||
                             args.at(n->input(2)).IValue()->isDouble()) {
                           float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
@@ -659,8 +656,7 @@ auto aten_registrations TORCHTRT_UNUSED =
                           return torch::arange(start_scalar, end_scalar, step_scalar);
                         }
                       } else {
-                        TORCHTRT_THROW_ERROR(
-                            "Invalid input argument size for aten::arange, input argument size: " << input_size);
+                        TORCHTRT_THROW_ERROR("Unsupported aten::arange variant: " << name);
                       }
                       return {};
                     },