Skip to content

Commit ee21e73

Browse files
committed
aten::split converter
1 parent c3a65ef commit ee21e73

File tree

3 files changed

+213
-2
lines changed

3 files changed

+213
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import logging
22
from typing import Any, Dict, Optional, Sequence, Tuple, Union
33

4+
import tensorrt as trt
45
import torch
56
from torch.fx.node import Argument, Node, Target
67
from torch_tensorrt.dynamo._SourceIR import SourceIR
78
from torch_tensorrt.dynamo.conversion import impl
89
from torch_tensorrt.dynamo.conversion.converter_utils import (
910
cast_int_int_div_trt_tensor,
1011
cast_trt_tensor,
12+
dynamic_unsupported,
1113
)
1214
from torch_tensorrt.fx.converters import acc_ops_converters
1315
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1416

15-
import tensorrt as trt
16-
1717
from .converter_registry import dynamo_tensorrt_converter
1818

1919
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -279,6 +279,22 @@ def aten_ops_unsqueeze(
279279
)
280280

281281

282+
@dynamo_tensorrt_converter(
283+
torch.ops.aten.split.default, capability_validator=dynamic_unsupported
284+
)
285+
@dynamo_tensorrt_converter(
286+
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported
287+
)
288+
def aten_ops_split(
289+
network: TRTNetwork,
290+
target: Target,
291+
args: Tuple[Argument, ...],
292+
kwargs: Dict[str, Argument],
293+
name: str,
294+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
295+
return impl.split(network, target, SourceIR.ATEN, name, args[0], args[1], args[3])
296+
297+
282298
@dynamo_tensorrt_converter(torch.ops.aten._softmax.default)
283299
def aten_ops_softmax(
284300
network: TRTNetwork,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
2+
3+
import numpy as np
4+
import torch
5+
import torch_tensorrt as trt
6+
from torch import Tensor
7+
from torch.fx.node import Target
8+
from torch_tensorrt.dynamo._SourceIR import SourceIR
9+
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
10+
from torch_tensorrt.fx.converters.converter_utils import (
11+
has_dynamic_shape,
12+
set_layer_name,
13+
)
14+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
15+
16+
17+
def acc_ops_split(
18+
network: TRTNetwork,
19+
target: Target,
20+
source_ir: Optional[SourceIR],
21+
name: str,
22+
input: TRTTensor,
23+
split_size_or_sections: Union[int, List(int)],
24+
dim: Optional[Any] = 0,
25+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
26+
if not isinstance(input, TRTTensor):
27+
raise RuntimeError(
28+
f"split received input {input} that is not part " "of the TensorRT region!"
29+
)
30+
31+
dim = cast(int, dim)
32+
dynamic_shape = has_dynamic_shape(input.shape)
33+
if network.has_implicit_batch_dimension:
34+
assert dim != 0, "Can't split on batch dim when it's implicit!"
35+
dim -= 1
36+
else:
37+
if dynamic_shape > 0:
38+
# Check whether slice target dim is dynamic shape dim
39+
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
40+
41+
split_sizes = []
42+
if type(split_size_or_sections) == int:
43+
split_sizes.append(cast(int, split_size_or_sections))
44+
else:
45+
for split_size_or_section in split_size_or_sections:
46+
split_sizes.append(cast(int, split_size_or_section))
47+
48+
start = [0] * len(input.shape)
49+
stride = [1] * len(start)
50+
offset = 0
51+
52+
if len(split_sizes) == 1:
53+
num_splits = input.shape[dim] + split_sizes[0] - 1 // split_sizes[0]
54+
split_sizes = [split_sizes[0]] * num_splits
55+
else:
56+
num_splits = len(split_sizes)
57+
58+
if num_splits < 1:
59+
raise RuntimeError(
60+
f"Invalid split: {input.shape[dim]} with split_size={split_sizes}"
61+
)
62+
63+
max_offset = input.shape[dim]
64+
# add slice layers
65+
output = []
66+
for i in range(num_splits):
67+
shape = list(input.shape)
68+
shape[dim] = min(split_sizes[i], cast(int, max_offset - offset))
69+
start[dim] = offset
70+
if dynamic_shape:
71+
shape = get_shape_with_dynamic_shape(
72+
network, shape, input, target, f"{name}_shape_{i}"
73+
)
74+
layer = network.add_slice(
75+
input, start=start, shape=[] if dynamic_shape else shape, stride=stride
76+
)
77+
if dynamic_shape:
78+
layer.set_input(2, shape)
79+
offset += split_sizes[i]
80+
set_layer_name(layer, target, f"{name}_{i}")
81+
output.append(layer.get_output(0))
82+
return output
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import torch
2+
from harness import DispatchTestCase
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
6+
7+
8+
# FIXME: check about implicit and explicit batch
9+
class TestSliceConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("select_split_size_or_sections_no_dim", 2),
13+
("select_split_size_or_sections_list_no_dim", [1, 4]),
14+
("select_split_size_or_sections_list_no_dim_not_full_split", [1, 3]),
15+
]
16+
)
17+
def test_slice(self, _, split_size_or_tensor):
18+
class TestModule(torch.nn.Module):
19+
def __init__(self):
20+
super().__init__()
21+
22+
def forward(self, input):
23+
out = torch.ops.aten.slice.Tensor(input, split_size_or_tensor)
24+
return out
25+
26+
input = torch.arange(10).reshape(5, 2)
27+
self.run_test(
28+
TestModule(),
29+
input,
30+
expected_ops={torch.ops.aten.slice.Tensor},
31+
)
32+
33+
34+
class TestSliceConverter(DispatchTestCase):
35+
@parameterized.expand(
36+
[
37+
("select_split_size_or_sections_dim", 2, 1),
38+
("select_split_size_or_sections_list_dim", [1, 4], 1),
39+
("select_split_size_or_sections_list_dim_not_full_split", [1, 3], 1),
40+
]
41+
)
42+
def test_slice(self, _, split_size_or_tensor, dim):
43+
class TestModule(torch.nn.Module):
44+
def __init__(self):
45+
super().__init__()
46+
47+
def forward(self, input):
48+
out = torch.ops.aten.slice.Tensor(split_size_or_tensor, dim)
49+
return out
50+
51+
input = torch.arange(10).reshape(2, 5)
52+
self.run_test(
53+
TestModule(),
54+
input,
55+
expected_ops={torch.ops.aten.slice.Tensor},
56+
)
57+
58+
59+
class TestSliceConverter(DispatchTestCase):
60+
@parameterized.expand(
61+
[
62+
("select_split_size_or_sections_dim", 2, 1),
63+
("select_split_size_or_sections_list_dim", [1, 4], 1),
64+
]
65+
)
66+
def test_slice(self, _, split_size_or_tensor, dim):
67+
class TestModule(torch.nn.Module):
68+
def __init__(self):
69+
super().__init__()
70+
71+
def forward(self, input):
72+
out = torch.ops.aten.slice.Tensor(input, split_size_or_tensor, dim)
73+
return out
74+
75+
input_specs = [
76+
Input(
77+
shape=(1, 10, -1),
78+
dtype=torch.float32,
79+
shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
80+
),
81+
]
82+
self.run_test_with_dynamic_shape(
83+
TestModule(),
84+
input_specs,
85+
expected_ops={torch.ops.aten.slice.Tensor},
86+
)
87+
88+
89+
class TestSplitSymIntConverterImplicitBatch(DispatchTestCase):
90+
@parameterized.expand(
91+
[
92+
("select_chunk_dim", 6, 0),
93+
]
94+
)
95+
def test_chunk(self, _, chunk, dim):
96+
class TestModule(torch.nn.Module):
97+
def __init__(self):
98+
super().__init__()
99+
100+
def forward(self, input):
101+
out = torch.ops.aten.chunk(input, chunk, dim)
102+
return out
103+
104+
input = [torch.randn(11)]
105+
self.run_test(
106+
TestModule(),
107+
input,
108+
expected_ops={torch.ops.aten.split},
109+
)
110+
111+
112+
if __name__ == "__main__":
113+
run_tests()

0 commit comments

Comments
 (0)