@@ -223,13 +223,20 @@ auto aten_registrations TORCHTRT_UNUSED =
223
223
{c10::Symbol::fromQualString (" aten::slice" ),
224
224
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
225
225
c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
226
-
227
226
int64_t start = 0 ;
227
+ int64_t end = 9223372036854775807 ;
228
228
auto startIVal = args.at (n->input (1 )).IValue ();
229
+ auto endIVal = args.at (n->input (2 )).IValue ();
230
+
229
231
if (!startIVal->isNone ()) {
230
232
start = args.at (n->input (1 )).unwrapToInt ();
231
233
}
232
- int64_t end = args.at (n->input (2 )).unwrapToInt ();
234
+ if (!endIVal->isNone ()) {
235
+ end = args.at (n->input (2 )).unwrapToInt ();
236
+ }
237
+ if (start > end) {
238
+ LOG_DEBUG (" The end should be greater than start" );
239
+ }
233
240
int64_t step = args.at (n->input (3 )).unwrapToInt ();
234
241
235
242
const int64_t list_size = list.size ();
@@ -253,8 +260,9 @@ auto aten_registrations TORCHTRT_UNUSED =
253
260
254
261
return sliced_list;
255
262
},
256
- EvalOptions ().validSchemas (
257
- {" aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])" })})
263
+ EvalOptions ().validSchemas ({" aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[])" })})
264
+ // EvalOptions().validSchemas(
265
+ // {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
258
266
.evaluator(
259
267
{c10::Symbol::fromQualString (" aten::len" ),
260
268
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
@@ -896,8 +904,14 @@ auto aten_registrations TORCHTRT_UNUSED =
896
904
auto step = args.at (n->input (2 )).unwrapToInt ();
897
905
return start + idx * step;
898
906
},
899
- EvalOptions ().validSchemas ({" aten::__derive_index(int idx, int start, int step) -> int" })});
900
-
907
+ EvalOptions ().validSchemas ({" aten::__derive_index(int idx, int start, int step) -> int" })})
908
+ .evaluator(
909
+ {c10::Symbol::fromQualString (" aten::list" ),
910
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
911
+ c10::List<c10::IValue> list = args.at (n->input (0 )).IValue ()->to <c10::List<c10::IValue>>();
912
+ return list.copy ();
913
+ },
914
+ EvalOptions ().validSchemas ({" aten::list.t(t[] l) -> (t[])" })});
901
915
} // namespace
902
916
} // namespace evaluators
903
917
} // namespace conversion
0 commit comments