
Commit 533215c

feat: support chunk dynamo converter (#2401)
1 parent cb20f90

3 files changed: +161 −0 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+24)
@@ -667,6 +667,30 @@ def aten_ops_slice(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_chunk(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.slice.chunk(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args_bounds_check(args, 2, 0),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.permute.default)  # type: ignore[misc]
 @enforce_tensor_types(
     {
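For context, `aten.chunk.default` may be called as `(input, chunks)` or `(input, chunks, dim)`, so the converter resolves `dim` to a default of 0 via `args_bounds_check(args, 2, 0)`. A minimal sketch of what such a bounds-checked lookup does (the real helper lives in the dynamo conversion utilities; this standalone version is illustrative only):

    from typing import Any, Sequence

    def args_bounds_check(
        args: Sequence[Any], i: int, replacement: Any = None
    ) -> Any:
        # Return args[i] when it exists, otherwise fall back to the default.
        return args[i] if len(args) > i else replacement

    # For aten.chunk.default called without an explicit dim:
    call_args = ("input_trt_tensor", 3)
    assert args_bounds_check(call_args, 2, 0) == 0  # dim defaults to 0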

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py (+55)
@@ -102,3 +102,58 @@ def expand(
     layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def chunk(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    chunks: int,
+    dim: int,
+) -> List[TRTTensor]:
+    if chunks <= 0:
+        raise RuntimeError(
+            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
+        )
+
+    shape = input.shape
+    dim = get_positive_dim(dim, len(shape))
+
+    if dim >= len(shape):
+        raise RuntimeError(
+            f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
+        )
+
+    dynamic_shape = has_dynamic_shape(input.shape)
+    if dynamic_shape:
+        # Check whether the slice target dim is a dynamic shape dim
+        assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
+
+    size_dim = shape[dim]
+    chunk_size = math.ceil(size_dim / chunks)
+    result = []
+    start = 0
+    end = min(start + chunk_size, size_dim)
+    cnt = 0
+
+    while start < end:
+        result.append(
+            slice_op(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_slice_{cnt}",
+                input,
+                dim,
+                start,
+                end,
+                1,
+            )
+        )
+        start = end
+        end = min(start + chunk_size, size_dim)
+        cnt += 1
+
+    return result
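Worth noting: like `torch.chunk`, this loop can return fewer than `chunks` slices. The chunk size is fixed at `ceil(size_dim / chunks)` and slicing simply stops when the dimension is exhausted, so `size_dim = 4` with `chunks = 3` yields the slices `[0:2)` and `[2:4)`, i.e. two chunks. A small pure-Python sketch of the same bounds computation, checked against eager `torch.chunk` (illustrative only, independent of the TensorRT code above):

    import math
    from typing import List, Tuple

    import torch

    def chunk_bounds(size_dim: int, chunks: int) -> List[Tuple[int, int]]:
        # Mirror the converter's loop: fixed chunk_size = ceil(size/chunks),
        # advancing until the dimension is exhausted.
        chunk_size = math.ceil(size_dim / chunks)
        bounds, start = [], 0
        end = min(chunk_size, size_dim)
        while start < end:
            bounds.append((start, end))
            start = end
            end = min(start + chunk_size, size_dim)
        return bounds

    x = torch.arange(4)
    assert [t.tolist() for t in torch.chunk(x, 3)] == [
        x[s:e].tolist() for s, e in chunk_bounds(4, 3)
    ]  # both produce [[0, 1], [2, 3]]: two chunks, not three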
tests/py/dynamo/conversion/test_chunk_aten.py (new file, +82)

@@ -0,0 +1,82 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestChunkConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((1,), 3, 0),
+            ((3,), 3, 0),
+            ((4,), 3, 0),
+            ((6,), 3, 0),
+            ((3,), 1, -1),
+            ((3,), 3, -1),
+            ((3,), 4, -1),
+        ]
+    )
+    def test_chunk_1D(self, shape, chunks, dim):
+        class TestChunk(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.chunk.default(input, chunks, dim)
+                return out
+
+        input = [torch.randn(shape)]
+        self.run_test(
+            TestChunk(),
+            input,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 4), 1, 0),
+            ((3, 4), 3, 0),
+            ((3, 4), 4, 0),
+            ((3, 4), 2, -2),
+            ((3, 4), 6, -2),
+            ((3, 4), 3, 1),
+            ((3, 4), 4, 1),
+            ((3, 4), 5, -1),
+        ]
+    )
+    def test_chunk_2D(self, shape, chunks, dim):
+        class TestChunk(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.chunk.default(input, chunks, dim)
+                return out
+
+        input = [torch.randn(shape)]
+        self.run_test(
+            TestChunk(),
+            input,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 4, 2), 1, 0),
+            ((3, 4, 2), 3, -3),
+            ((3, 4, 2), 3, 1),
+            ((3, 4, 2), 4, 1),
+            ((3, 4, 2), 6, -2),
+            ((3, 4, 2), 1, 2),
+            ((3, 4, 2), 3, -1),
+            ((3, 4, 2), 4, -1),
+        ]
+    )
+    def test_chunk_3D(self, shape, chunks, dim):
+        class TestChunk(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.chunk.default(input, chunks, dim)
+                return out
+
+        input = [torch.randn(shape)]
+        self.run_test(
+            TestChunk(),
+            input,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
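The cases above deliberately include `chunks` larger than the target dimension (e.g. `((1,), 3, 0)` and `((3,), 4, -1)`), where `torch.chunk` returns fewer chunks than requested; the `DispatchTestCase` harness compares the TensorRT-lowered module against this eager behavior. A quick eager-mode check of those expectations (plain PyTorch, no converter involved):

    import torch

    # When chunks exceeds the dimension size, fewer pieces come back.
    assert len(torch.chunk(torch.randn(1), 3, 0)) == 1
    assert len(torch.chunk(torch.randn(3), 4, -1)) == 3
    # ceil(4 / 3) == 2, so a size-4 dim splits into two chunks of 2.
    assert len(torch.chunk(torch.randn(4), 3, 0)) == 2
    # dim -2 of (3, 4) has size 3; chunks=6 still yields 3 singleton chunks.
    assert len(torch.chunk(torch.randn(3, 4), 6, -2)) == 3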
