
Commit 19aabdd

apbose and gs-olive authored
aten::split converter (#2232)
Co-authored-by: gs-olive <[email protected]>
1 parent c875c39 commit 19aabdd

File tree

5 files changed: +322 −16 lines changed

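For context, a minimal usage sketch (not part of the commit) of what this converter enables: a module that calls torch.split can now be lowered through the Torch-TensorRT dynamo frontend rather than falling back to PyTorch. This assumes a CUDA build where ir="dynamo" is available; min_block_size=1 is only used here to force even this tiny graph into TensorRT.

import torch
import torch_tensorrt

class SplitModule(torch.nn.Module):
    def forward(self, x):
        # traced as aten.split.Tensor, handled by the converter added below
        return torch.split(x, 2, dim=1)

model = SplitModule().eval().cuda()
inputs = [torch.randn(2, 6, 8).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
print([t.shape for t in trt_model(*inputs)])  # three chunks of size 2 along dim 1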

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+29)

@@ -11,6 +11,7 @@
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

 from .converter_registry import dynamo_tensorrt_converter
+from .converter_utils import dynamic_unsupported_with_args

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -354,6 +355,34 @@ def aten_ops_softmax(
     )


+@dynamo_tensorrt_converter(
+    torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.split_with_sizes.default,
+    capability_validator=dynamic_unsupported_with_args([1]),
+)
+def aten_ops_split(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.split.split(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        split_size_or_sections=args[1],
+        dim=args_bounds_check(args, 2, 0),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.where.self)  # type: ignore[misc]
 def aten_ops_where(
     network: TRTNetwork,
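The three decorators register a single converter for the three split overloads the dynamo frontend can produce. In each overload args[0] is the input tensor, args[1] the split size or list of section sizes, and the optional args[2] the dimension. The real default handling uses args_bounds_check, which is imported elsewhere in this module and not shown in the diff; the stand-in below only illustrates its assumed behaviour.

# Illustrative stand-in (assumption, not the library's definition): fall back to a
# default when an optional positional arg was not supplied by the traced graph.
def args_bounds_check(args, i, replacement=None):
    return args[i] if len(args) > i and args[i] is not None else replacement

# aten.split.Tensor(x, split_size, dim=0) traced with only two args:
node_args = ("input_trt_tensor", 2)
assert args_bounds_check(node_args, 2, 0) == 0  # dim defaults to 0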

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+36 −16)

@@ -1,11 +1,12 @@
 import functools
 import logging
 import re
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import numpy as np
 import tensorrt as trt
 import torch
+from torch import SymBool, SymFloat, SymInt
 from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
@@ -60,34 +61,53 @@ def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:


 def dynamic_unsupported(node: torch.fx.Node) -> bool:
+    """Validates that a node has no dynamic args, kwargs, or outputs"""
+    return _dynamic_unsupported(node=node)
+
+
+def dynamic_unsupported_with_args(
+    arg_positions_to_check: Optional[List[int]] = None,
+) -> Callable[[torch.fx.Node], bool]:
+    """Returns a validator that a node has no dynamic args at specific positions"""
+    return functools.partial(
+        _dynamic_unsupported, arg_positions_to_check=arg_positions_to_check
+    )
+
+
+def _dynamic_unsupported(
+    node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
+) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"

+    def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
+        """Checks if a node itself has Dynamic properties"""
+        return getattr(
+            subnode.meta["val"], "_has_symbolic_sizes_strides", False
+        ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool))
+
     # Check node value itself
-    if ("val" in node.meta) and getattr(
-        node.meta["val"], "_has_symbolic_sizes_strides", False
-    ):
+    if arg_positions_to_check is None and _is_subnode_dynamic(node):
         return False

     # Check node arguments individually
-    if any(
-        (
-            ("val" in arg.meta)
-            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
-        )
-        for arg in node.args
-        if isinstance(arg, torch.fx.Node)
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
+    ):
+        return False
+    # Check specific arg positions if the caller has specified positions to check
+    elif arg_positions_to_check is not None and any(
+        _is_subnode_dynamic(node.args[i])
+        for i in arg_positions_to_check
+        if isinstance(node.args[i], torch.fx.Node)
     ):
         return False

     # Check node keyword arguments individually
-    if any(
-        (
-            ("val" in kwarg.meta)
-            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
-        )
+    if arg_positions_to_check is None and any(
+        _is_subnode_dynamic(kwarg)
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
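A short sketch of how the new factory is consumed, mirroring the decorators in aten_ops_converters.py: the registry calls the returned predicate on each candidate FX node, and a False result leaves the op to run in PyTorch instead of being converted. The registry-side control flow in the comments is a simplification, not a quote of the library.

from torch_tensorrt.dynamo.conversion.converter_utils import dynamic_unsupported_with_args

# Returns a per-op validator; here only arg position 1 (split_size_or_sections)
# is inspected for symbolic values.
validator = dynamic_unsupported_with_args([1])

# Conceptually, during partitioning/conversion the registry does roughly:
#   if not validator(node):     # split size is a SymInt/SymFloat/SymBool
#       leave node to PyTorch   # converter is skipped for this node
#   else:
#       convert node with aten_ops_split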

py/torch_tensorrt/dynamo/conversion/impl/__init__.py (+1)

@@ -15,6 +15,7 @@
     select,
     shape,
     slice,
+    split,
     squeeze,
     unary,
     unsqueeze,
py/torch_tensorrt/dynamo/conversion/impl/split.py (new file, +81)

@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch_tensorrt as trt
from torch import Tensor
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    set_layer_name,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def split(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    split_size_or_sections: Union[int, List[int]],
    dim: int = 0,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    if not isinstance(input, TRTTensor):
        raise RuntimeError(
            f"split received input {input} that is not part " "of the TensorRT region!"
        )

    dynamic_shape = has_dynamic_shape(input.shape)
    if dynamic_shape > 0:
        # Check whether slice target dim is dynamic shape dim
        assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"

    split_sizes = []
    if isinstance(split_size_or_sections, int):
        split_sizes.append(split_size_or_sections)
    else:
        for split_size_or_section in split_size_or_sections:
            split_sizes.append(split_size_or_section)

    start = [0] * len(input.shape)
    stride = [1] * len(start)
    offset = 0
    if len(split_sizes) == 1:
        num_splits = (input.shape[dim] + split_sizes[0] - 1) // split_sizes[0]
        split_sizes = [split_sizes[0]] * num_splits
    else:
        num_splits = len(split_sizes)
        sum_split_sizes = sum(split_sizes)
        if sum_split_sizes != input.shape[dim]:
            raise RuntimeError(
                f"split sizes don't add up to the tensor's size in the given dimension"
            )

    if num_splits < 1:
        raise RuntimeError(
            f"Invalid split: {input.shape[dim]} with split_size={split_sizes}"
        )

    max_offset = input.shape[dim]
    # add slice layers
    output = []
    for i in range(num_splits):
        shape = list(input.shape)
        shape[dim] = min(split_sizes[i], max_offset - offset)
        start[dim] = offset
        if dynamic_shape:
            shape = get_shape_with_dynamic_shape(
                network, target, source_ir, f"{name}_shape_{i}", shape, input
            )
        layer = network.add_slice(
            input, start=start, shape=[] if dynamic_shape else shape, stride=stride
        )
        if dynamic_shape:
            layer.set_input(2, shape)
        offset += split_sizes[i]
        set_layer_name(layer, target, f"{name}_{i}")
        output.append(layer.get_output(0))
    return output
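To make the slice arithmetic above concrete, here is a plain-Python worked sketch (no TensorRT) of how a single integer split size becomes per-slice lengths; each length becomes the shape of one slice layer along dim, with the start advancing by the nominal split size.

# torch.split semantics for an int split size: ceil-divide the dimension length,
# then clamp the final chunk. Mirrors the num_splits / min(...) logic above.
dim_len, split_size = 5, 2
num_splits = (dim_len + split_size - 1) // split_size  # 3 slices
lengths, offset = [], 0
for _ in range(num_splits):
    lengths.append(min(split_size, dim_len - offset))  # last chunk is clamped
    offset += split_size
assert lengths == [2, 2, 1]  # matches torch.split on a size-5 dimension with split_size=2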
New test file for the aten::split converter (+175)

@@ -0,0 +1,175 @@
import torch
from .harness import DispatchTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException


# FIXME: check about implicit and explicit batch
class TestSplitConverterNoDim(DispatchTestCase):
    @parameterized.expand(
        [
            ("split_size_or_sections_no_dim", 2),
        ]
    )
    def test_split(self, _, split_size_or_tensor):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                out = torch.split(input, split_size_or_tensor)
                return out

        input = [torch.randn(10).reshape(5, 2)]
        self.run_test(
            TestModule(),
            input,
            expected_ops={torch.ops.aten.split.Tensor},
            disable_passes=True,
        )

    @parameterized.expand(
        [
            ("split_size_or_sections_list_no_dim_list", [1, 4]),
        ]
    )
    def test_split_list(self, _, split_size_or_tensor):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                out = torch.split(input, split_size_or_tensor)
                return out

        input = [torch.randn(10).reshape(5, 2)]
        self.run_test(
            TestModule(),
            input,
            expected_ops={torch.ops.aten.split_with_sizes.default},
            disable_passes=True,
        )

    @parameterized.expand(
        [
            ("split_size_or_sections_dims", 2, 1),
        ]
    )
    def test_split(self, _, split_size_or_tensor, dim):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                out = torch.split(input, split_size_or_tensor, dim)
                return out

        input = [torch.randn(10).reshape(5, 2)]
        self.run_test(
            TestModule(),
            input,
            expected_ops={torch.ops.aten.split.Tensor},
            disable_passes=True,
        )

    @parameterized.expand(
        [
            ("split_size_or_sections_list_dims", [1, 1], 1),
        ]
    )
    def test_split_dim_list(self, _, split_size_or_tensor, dim):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                out = torch.split(input, split_size_or_tensor, dim)
                return out

        input = [torch.randn(10).reshape(5, 2)]
        self.run_test(
            TestModule(),
            input,
            expected_ops={torch.ops.aten.split_with_sizes.default},
            disable_passes=True,
        )

    @parameterized.expand(
        [
            ("split_size_or_sections_list_dims_not_full_list", [1, 1], 1),
        ]
    )
    def test_split_dim_list(self, _, split_size_or_tensor, dim):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                out = torch.split(input, split_size_or_tensor, dim)
                return out

        input = [torch.randn(15).reshape(5, 3)]
        with self.assertRaises(RuntimeError):
            self.run_test(
                TestModule(),
                input,
                expected_ops={torch.ops.aten.split_with_sizes.default},
                disable_passes=True,
            )

    @parameterized.expand(
        [
            ("select_split_size_or_sections_dim_dynamic_shape", 2, 1),
        ]
    )
    def test_split_dynamic(self, _, split_size_or_tensor, dim):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                out = torch.split(input, split_size_or_tensor, dim)
                return out

        input_specs = [
            Input(
                shape=(1, 10, -1),
                dtype=torch.float32,
                shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
            ),
        ]
        self.run_test_with_dynamic_shape(
            TestModule(),
            input_specs,
            expected_ops={torch.ops.aten.split.Tensor},
            disable_passes=True,
        )

    @parameterized.expand(
        [
            ("select_chunk_dim", 6, 0),
        ]
    )
    def test_split_dynamic(self, _, chunk, dim):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                out = torch.ops.aten.chunk(input, chunk, dim)
                return out

        input = [torch.randn(11)]
        with self.assertRaises(UnsupportedOperatorException):
            self.run_test(
                TestModule(),
                input,
                expected_ops={torch.ops.aten.split.Tensor},
                disable_passes=True,
            )


if __name__ == "__main__":
    run_tests()
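These cases run either by executing the test file directly (it calls run_tests()) or through pytest. The path of the new test file is not captured in this view, so the invocation below is only indicative.

# Indicative only -- the test module path is an assumption, not shown in this diff view.
#   python -m pytest tests/py/dynamo/conversion -k "split" -x
# or, running the file directly:
#   python <path_to_this_test_file>.py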
