Skip to content

Commit 5ef1dec

Browse files
authored
chore: Minor fix 2.3 (#2866)
1 parent 25a04ae commit 5ef1dec

File tree

1 file changed

+40
-1
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/slice

1 file changed

+40
-1
lines changed

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
get_positive_dim,
1212
get_trt_tensor,
1313
)
14+
from torch_tensorrt.dynamo.conversion.impl.cat import cat
1415
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
1516
from torch_tensorrt.fx.converters.converter_utils import (
1617
has_dynamic_shape,
@@ -99,7 +100,45 @@ def expand(
99100
[int(i == o) for i, o in zip(input_tensor_shape, shape)]
100101
) # stride == 1 if dimensions match, 0 otherwise
101102

102-
layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
103+
shape_ = shape
104+
# Handle dynamic shapes case where shape has dynamic dimension
105+
if any(isinstance(ele, TRTTensor) for ele in shape):
106+
shape_ = cat(
107+
ctx,
108+
target,
109+
source_ir,
110+
name + "_shape_concat",
111+
shape,
112+
0,
113+
cast_dtype=trt.int32,
114+
)
115+
start_tensor = cat(
116+
ctx,
117+
target,
118+
source_ir,
119+
name + "_start_concat",
120+
start,
121+
0,
122+
cast_dtype=trt.int32,
123+
)
124+
stride_tensor = cat(
125+
ctx,
126+
target,
127+
source_ir,
128+
name + "_stride_concat",
129+
stride,
130+
0,
131+
cast_dtype=trt.int32,
132+
)
133+
layer = ctx.net.add_slice(
134+
input_t, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
135+
)
136+
layer.set_input(1, start_tensor)
137+
layer.set_input(2, shape_)
138+
layer.set_input(3, stride_tensor)
139+
else:
140+
layer = ctx.net.add_slice(input_t, start=start, shape=shape_, stride=stride)
141+
103142
set_layer_name(layer, target, name, source_ir)
104143
return layer.get_output(0)
105144

0 commit comments

Comments
 (0)