@@ -197,19 +197,30 @@ auto select_registrations TRTORCH_UNUSED =
197
197
return true ;
198
198
}})
199
199
.pattern(
200
- {" aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0 , int end=9223372036854775807 , int step=1) -> Tensor(a)" ,
200
+ {" aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None , int? end=None , int step=1) -> Tensor(a)" ,
201
201
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
202
202
auto in = args[0 ].ITensor ();
203
203
auto axis = args[1 ].unwrapToInt ();
204
204
auto maxDim = static_cast <int64_t >(in->getDimensions ().d [axis]);
205
+ auto startIdx = 0 ;
206
+ auto startIdxIVal = args[2 ].IValue ();
207
+ if (!startIdxIVal->isNone ()) {
208
+ startIdx = startIdxIVal->toInt ();
209
+ }
205
210
// Handle case when given tensor index is negative
206
- auto startIdx = args[2 ].unwrapToInt ();
207
211
auto start = (startIdx < 0 ) ? (maxDim + startIdx) : startIdx;
208
212
// Bound the end index to input tensor dimensions at specified axis
209
- auto endIdx = std::min (args[3 ].unwrapToInt (), maxDim);
213
+ auto endIdx = maxDim;
214
+ auto endIdxIVal = args[3 ].IValue ();
215
+ if (!endIdxIVal->isNone ()) {
216
+ endIdx = std::min (endIdxIVal->toInt (), maxDim);
217
+ }
210
218
auto end = (endIdx < 0 ) ? (maxDim + endIdx) : endIdx;
211
219
auto step = args[4 ].unwrapToInt ();
212
220
221
+ LOG_DEBUG (" Start idx: " << start);
222
+ LOG_DEBUG (" End idx: " << end);
223
+
213
224
// indices to be accessed need to be an at::Tensor
214
225
at::Tensor indices = torch::arange (start, end, step).to (torch::kI32 );
215
226
auto weights = Weights (ctx, indices);
0 commit comments