From 6ea1446c49e9f4cc60c37fb0c4d74cbe894c783b Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 28 May 2024 13:24:44 -0700 Subject: [PATCH] chore: minor fix --- .../dynamo/conversion/impl/slice/ops.py | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 61d71fe9a0..8d5ba644a0 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -11,6 +11,7 @@ get_positive_dim, get_trt_tensor, ) +from torch_tensorrt.dynamo.conversion.impl.cat import cat from torch_tensorrt.dynamo.conversion.impl.slice.base import slice from torch_tensorrt.fx.converters.converter_utils import ( has_dynamic_shape, @@ -99,7 +100,45 @@ def expand( [int(i == o) for i, o in zip(input_tensor_shape, shape)] ) # stride == 1 if dimensions match, 0 otherwise - layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride) + shape_ = shape + # Handle dynamic shapes case where shape has dynamic dimension + if any(isinstance(ele, TRTTensor) for ele in shape): + shape_ = cat( + ctx, + target, + source_ir, + name + "_shape_concat", + shape, + 0, + cast_dtype=trt.int32, + ) + start_tensor = cat( + ctx, + target, + source_ir, + name + "_start_concat", + start, + 0, + cast_dtype=trt.int32, + ) + stride_tensor = cat( + ctx, + target, + source_ir, + name + "_stride_concat", + stride, + 0, + cast_dtype=trt.int32, + ) + layer = ctx.net.add_slice( + input_t, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims() + ) + layer.set_input(1, start_tensor) + layer.set_input(2, shape_) + layer.set_input(3, stride_tensor) + else: + layer = ctx.net.add_slice(input_t, start=start, shape=shape_, stride=stride) + set_layer_name(layer, target, name, source_ir) return layer.get_output(0)