Skip to content

Commit d3454dd

Browse files
committed
fix: allow slice_scatter decomposition with SymInt parameters for the case when returning the source tensor
1 parent 9b78101 commit d3454dd

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -203,17 +203,17 @@ def slice_scatter_decomposition(
203203
if step is None:
204204
step = 1
205205

206-
# Ensure start, end, and step are all integers
207-
assert isinstance(start, int), "start must be an integer"
208-
assert isinstance(end, int), "end must be an integer"
209-
assert isinstance(step, int), "step must be an integer"
210-
211206
src_dim = src_tensor.shape
212207
# step == 0 is not a valid torch case
213208
# also src_dim should be equal to slice dimension
214209

215210
if start == 0 and end == dim_size and step == 1:
216211
return src_tensor
212+
213+
# Ensure start, end, and step are all integers
214+
assert isinstance(start, int), "start must be an integer"
215+
assert isinstance(end, int), "end must be an integer"
216+
assert isinstance(step, int), "step must be an integer"
217217

218218
cat_tensors = []
219219
index_tensor_shape = []

0 commit comments

Comments
 (0)