     get_positive_dim,
     get_trt_tensor,
 )
+from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
@@ -99,7 +100,45 @@ def expand(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise

-    layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
+    shape_ = shape
+    # Handle dynamic shapes case where shape has dynamic dimension
+    if any(isinstance(ele, TRTTensor) for ele in shape):
+        shape_ = cat(
+            ctx,
+            target,
+            source_ir,
+            name + "_shape_concat",
+            shape,
+            0,
+            cast_dtype=trt.int32,
+        )
+        start_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            name + "_start_concat",
+            start,
+            0,
+            cast_dtype=trt.int32,
+        )
+        stride_tensor = cat(
+            ctx,
+            target,
+            source_ir,
+            name + "_stride_concat",
+            stride,
+            0,
+            cast_dtype=trt.int32,
+        )
+        layer = ctx.net.add_slice(
+            input_t, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
+        )
+        layer.set_input(1, start_tensor)
+        layer.set_input(2, shape_)
+        layer.set_input(3, stride_tensor)
+    else:
+        layer = ctx.net.add_slice(input_t, start=start, shape=shape_, stride=stride)
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

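The dynamic branch added above follows TensorRT's standard recipe for a slice whose parameters are only known at runtime: the ISliceLayer is created with empty trt.Dims() placeholders, and the real start/shape/stride are wired in as tensors via set_input on optional inputs 1-3 (the converter's cat helper first packs the mixed int / TRTTensor dimensions into 1-D INT32 tensors). Below is a minimal sketch of that pattern against the raw TensorRT Python API, not torch_tensorrt's converter helpers; network, inp, and shape_tensor are illustrative placeholders rather than names from this PR, and the start/stride constants are the generic all-zeros/all-ones case (for expand, the stride entry would be 0 on broadcast dimensions, as computed earlier in the function).

import numpy as np
import tensorrt as trt


def add_dynamic_slice(
    network: trt.INetworkDefinition,
    inp: trt.ITensor,
    shape_tensor: trt.ITensor,  # 1-D INT32 tensor holding the runtime output shape
) -> trt.ITensor:
    rank = len(inp.shape)
    # Static start (all zeros) and stride (all ones), expressed as constant tensors
    # so they can be fed to the slice layer through set_input.
    start = network.add_constant((rank,), np.zeros(rank, dtype=np.int32)).get_output(0)
    stride = network.add_constant((rank,), np.ones(rank, dtype=np.int32)).get_output(0)
    # Placeholder Dims: the actual values come from the tensors attached below.
    layer = network.add_slice(inp, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims())
    layer.set_input(1, start)         # runtime start
    layer.set_input(2, shape_tensor)  # runtime output shape
    layer.set_input(3, stride)        # runtime stride
    return layer.get_output(0)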