Skip to content

Commit a11287f

Browse files
committed
feat(aten::slice): Patching slice for new optional params
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 254eab2 commit a11287f

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

Diff for: core/conversion/converters/impl/select.cpp

+14-3
Original file line numberDiff line numberDiff line change
@@ -197,19 +197,30 @@ auto select_registrations TRTORCH_UNUSED =
197197
return true;
198198
}})
199199
.pattern(
200-
{"aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)",
200+
{"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
201201
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
202202
auto in = args[0].ITensor();
203203
auto axis = args[1].unwrapToInt();
204204
auto maxDim = static_cast<int64_t>(in->getDimensions().d[axis]);
205+
auto startIdx = 0;
206+
auto startIdxIVal = args[2].IValue();
207+
if (!startIdxIVal->isNone()) {
208+
startIdx = startIdxIVal->toInt();
209+
}
205210
// Handle case when given tensor index is negative
206-
auto startIdx = args[2].unwrapToInt();
207211
auto start = (startIdx < 0) ? (maxDim + startIdx) : startIdx;
208212
// Bound the end index to input tensor dimensions at specified axis
209-
auto endIdx = std::min(args[3].unwrapToInt(), maxDim);
213+
auto endIdx = maxDim;
214+
auto endIdxIVal = args[3].IValue();
215+
if (!endIdxIVal->isNone()) {
216+
endIdx = std::min(endIdxIVal->toInt(), maxDim);
217+
}
210218
auto end = (endIdx < 0) ? (maxDim + endIdx) : endIdx;
211219
auto step = args[4].unwrapToInt();
212220

221+
LOG_DEBUG("Start idx: " << start);
222+
LOG_DEBUG("End idx: " << end);
223+
213224
// indices to be accessed need to be an at::Tensor
214225
at::Tensor indices = torch::arange(start, end, step).to(torch::kI32);
215226
auto weights = Weights(ctx, indices);

Diff for: tests/core/conversion/converters/test_select.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,15 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) {
170170
TEST(Converters, ATenSliceConvertsCorrectly) {
171171
const auto graph = R"IR(
172172
graph(%x.1 : Tensor):
173-
%2 : int = prim::Constant[value=9223372036854775807]()
173+
%2 : None = prim::Constant()
174174
%3 : int = prim::Constant[value=2]()
175175
%4 : int = prim::Constant[value=4]()
176176
%5 : int = prim::Constant[value=1]()
177177
%6 : int = prim::Constant[value=0]()
178178
%7 : Tensor = aten::select(%x.1, %6, %6)
179179
%8 : Tensor = aten::select(%7, %6, %5)
180180
%9 : Tensor = aten::slice(%8, %6, %5, %4, %3)
181-
%10 : Tensor = aten::slice(%9, %5, %6, %2, %5)
181+
%10 : Tensor = aten::slice(%9, %5, %2, %2, %5)
182182
return (%10))IR";
183183

184184
auto g = std::make_shared<torch::jit::Graph>();

0 commit comments

Comments
 (0)