
Commit b251987

fix/feat: Add and repair multiple converters
- Focus on SD-performance-accelerating converters
- Add test cases for converters to avoid regressions
- Add prims sum converter
1 parent 16c670a commit b251987

11 files changed (+255, -34 lines)

py/torch_tensorrt/dynamo/conversion/__init__.py

+2-1
@@ -2,5 +2,6 @@
 from ._TRTInterpreter import * # noqa: F403
 from .aten_ops_converters import * # noqa: F403
 from .conversion import * # noqa: F403
-from .op_evaluators import * # noqa: F403
+from .ops_evaluators import * # noqa: F403
+from .prims_ops_converters import * # noqa: F403
 from .truncate_long_and_double import repair_long_or_double_inputs

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+39-6
@@ -27,6 +27,24 @@ def args_bounds_check(
     return args[i] if len(args) > i else replacement
 
 
+def get_ir(target: Target) -> SourceIR:
+    target_module = getattr(target, "__module__", "None")
+    if any(
+        target_module.startswith(prefix)
+        for prefix in ("torch.ops.aten", "torch._ops.aten")
+    ):
+        return SourceIR.ATEN
+    elif any(
+        target_module.startswith(prefix)
+        for prefix in ("torch.ops.prims", "torch._ops.prims")
+    ):
+        return SourceIR.PRIM
+    elif target_module.startswith("torch.nn"):
+        return SourceIR.NN
+
+    return SourceIR.UNKNOWN
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.batch_norm)  # type: ignore[misc]
 def aten_ops_batch_norm(
     ctx: ConversionContext,
@@ -674,23 +692,37 @@ def aten_ops_amax(
 
 @dynamo_tensorrt_converter(torch.ops.aten.sum.default)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.prims.sum.default)  # type: ignore[misc]
 def aten_ops_sum(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.reduce.sum(
+    sum_ = impl.reduce.sum(
         ctx,
         target,
-        SourceIR.ATEN,
+        get_ir(target),
         name,
         args[0],
         args_bounds_check(args, 1, replacement=None),
         args_bounds_check(args, 2, replacement=False),
     )
 
+    if kwargs.get("output_dtype", None) is not None:
+        return impl.cast.to_copy(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            sum_,
+            kwargs["output_dtype"],
+            force_layer=False,
+        )
+    else:
+        return sum_
+
 
 @dynamo_tensorrt_converter(torch.ops.aten.exp.default)  # type: ignore[misc]
 def aten_ops_exp(
@@ -1189,6 +1221,7 @@ def aten_ops_sub(
 @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.prims.div.default)  # type: ignore[misc]
 def aten_ops_div(
     ctx: ConversionContext,
     target: Target,
@@ -1202,7 +1235,7 @@ def aten_ops_div(
         return impl.elementwise.div(
             ctx,
             target,
-            SourceIR.ATEN,
+            get_ir(target),
             name,
             args[0],
             args[1],
@@ -1211,7 +1244,7 @@ def aten_ops_div(
         return impl.elementwise.floor_divide(
             ctx,
             target,
-            SourceIR.ATEN,
+            get_ir(target),
             name,
             args[0],
             args[1],
@@ -1220,7 +1253,7 @@ def aten_ops_div(
         return impl.elementwise.trunc_div(
             ctx,
             target,
-            SourceIR.ATEN,
+            get_ir(target),
             name,
             args[0],
             args[1],
@@ -1553,5 +1586,5 @@ def tensorrt_scaled_dot_product_attention(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.attention.scaled_dot_product_attention(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+        ctx, target, SourceIR.TORCHTRT_LOWERED, name, args[0], args[1], args[2]
     )
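
For illustration only (not part of this commit's diff): a quick check of how the new get_ir helper is expected to classify the registered targets, assuming the prefix checks above. The exact __module__ strings are whatever PyTorch reports for the op overloads, which is why both the torch.ops.* and torch._ops.* prefixes are tested.

# Hypothetical usage sketch of get_ir; run inside an environment with torch_tensorrt installed.
import torch
from torch_tensorrt.dynamo.conversion.aten_ops_converters import get_ir

print(get_ir(torch.ops.aten.sum.dim_IntList))  # expected: SourceIR.ATEN
print(get_ir(torch.ops.prims.sum.default))     # expected: SourceIR.PRIM
print(get_ir(torch.ops.prims.div.default))     # expected: SourceIR.PRIM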

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

+23-24
@@ -1,5 +1,6 @@
 from typing import Optional
 
+import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -23,16 +24,6 @@ def where(
     other: TRTTensor,
     condition: TRTTensor,
 ) -> TRTTensor:
-    input_dim = len(tuple(input.shape))
-    other_dim = len(tuple(other.shape))
-    condition_dim = len(tuple(condition.shape))
-
-    if type(input) != TRTTensor:
-        assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!"
-
-    if type(other) != TRTTensor:
-        assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!"
-
     if not (broadcastable(input, other)):
         assert "The two torch tensors should be broadcastable"
 
@@ -49,33 +40,37 @@
     x_shape = list(input.shape)
     y_shape = list(other.shape)
     condition_shape = list(condition.shape)
+
     output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
 
     # expand shape
-    if type(condition) != TRTTensor:
-        assert condition.dtype == torch.bool, "condition dtype is not bool"
+    if not isinstance(condition, TRTTensor):
+        assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
         if condition_shape != output_shape:
-            condition.expand(output_shape)
-            condition = condition.to(torch.int32)
-            condition_const = get_trt_tensor(ctx, condition, f"{name}_condition")
-            condition_layer = ctx.net.add_identity(condition_const)
-            condition_layer.set_output_type(0, trt.bool)
-            set_layer_name(condition_layer, target, f"{name}_condition")
-            condition_val = condition_layer.get_output(0)
+            condition = (
+                condition.expand(output_shape)
+                if isinstance(condition, torch.Tensor)
+                else np.broadcast_to(condition, output_shape)
+            )
+        condition_val = get_trt_tensor(ctx, condition, f"{name}_condition")
     else:
         assert condition.dtype == trt.bool, "mask dtype is not bool!"
-        if len(condition_shape) != condition_dim:
+        if condition_shape != output_shape:
             condition_val = expand(
                 ctx, target, source_ir, f"{name}_expand", condition, output_shape
             )
         else:
             condition_val = condition
 
-    if type(input) != TRTTensor:
+    if not isinstance(input, TRTTensor):
         if x_shape != output_shape:
             # special case where 1 element in input
             if len(input.shape) == 0:
-                input = input.unsqueeze(0)
+                input = (
+                    input.unsqueeze(0)
+                    if isinstance(input, torch.Tensor)
+                    else np.expand_dims(input, axis=0)
+                )
             input = input.expand(output_shape)
             x_val = get_trt_tensor(ctx, input, f"{name}_x")
         else:
@@ -85,11 +80,15 @@
                 ctx, target, source_ir, f"{name}_x_expand", input, output_shape
             )
 
-    if type(other) != TRTTensor:
+    if not isinstance(other, TRTTensor):
         if y_shape != output_shape:
             # special case where 1 element in other
             if len(other.shape) == 0:
-                other = other.unsqueeze(0)
+                other = (
+                    other.unsqueeze(0)
+                    if isinstance(other, torch.Tensor)
+                    else np.expand_dims(other, axis=0)
+                )
             other = other.expand(output_shape)
             y_val = get_trt_tensor(ctx, other, f"{name}_y")
         else:
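
For illustration only (outside the diff): the numpy handling added to where() assumes a non-TRT condition or operand may arrive either as a torch.Tensor or as a numpy constant, and broadcasts it with the matching library call before it is turned into a TRT constant. A minimal standalone sketch of the two branches:

# Standalone sketch of the torch-vs-numpy broadcasting paths used in where(); example shapes only.
import numpy as np
import torch

output_shape = (2, 3)
cond_torch = torch.zeros((2, 1), dtype=torch.bool)  # condition captured as a torch constant
cond_numpy = np.zeros((2, 1), dtype=np.bool_)       # condition captured as a numpy constant

print(tuple(cond_torch.expand(output_shape).shape))     # (2, 3) -- torch path
print(np.broadcast_to(cond_numpy, output_shape).shape)  # (2, 3) -- numpy path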

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

+1-1
@@ -51,7 +51,7 @@ def sum(
     ):
         input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
 
-    if dim is None:
+    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
         dim = tuple(range(len(input_val.shape)))
     layer = ctx.net.add_reduce(
         input_val,
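
For illustration only (outside the diff): the added clause makes an empty dim sequence behave like dim=None, i.e. a full reduction over every axis. A rough sketch with made-up values:

# Hypothetical values showing the dim normalization added above.
input_rank = 3
dim = []  # e.g. a sum requested over all axes via an empty dim list
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
    dim = tuple(range(input_rank))
print(dim)  # (0, 1, 2)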

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

+40-1
@@ -1,4 +1,4 @@
-from typing import Optional, cast
+from typing import List, Optional, Sequence, cast
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -49,3 +49,42 @@ def unsqueeze(
     )
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def broadcast_in_dim(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_t: TRTTensor,
+    shape: Sequence[int],
+    broadcast_dimensions: Sequence[int],
+) -> TRTTensor:
+    augmented_shape_list: List[Optional[int]] = list(shape)
+
+    # For each dimension being broadcasted, set the augmented shape to None
+    for broadcast_dim in broadcast_dimensions:
+        augmented_shape_list[broadcast_dim] = None
+
+    # TODO: Expand support to arbitrary broadcasts
+    assert all(
+        dim in (1, None) for dim in augmented_shape_list
+    ), "broadcast_in_dim currently only supports unsqueeze broadcasting"
+
+    # Unsqueeze the shape repeatedly to broadcast
+    output = input_t
+    for idx, x in enumerate(augmented_shape_list):
+        # If the value is not None, that dimension is to be broadcasted
+        if x is not None:
+            output = unsqueeze(
+                ctx,
+                target,
+                source_ir,
+                name + f"_unsqueeze_for_broadcast_{idx}",
+                output,
+                idx,
+            )
+
+    assert tuple(output.shape) == tuple(shape), "broadcast_in_dim shapes don't match"
+
+    return output
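
For illustration only (outside the diff): an eager-mode mirror of the unsqueeze loop above with made-up example shapes. Every target position not listed in broadcast_dimensions must have size 1 (enforced by the assert) and is produced by an unsqueeze.

# Eager-mode mirror of broadcast_in_dim's unsqueeze-only broadcasting; example shapes only.
import torch

x = torch.randn(3, 4)
shape = (1, 3, 1, 4)           # target shape
broadcast_dimensions = (1, 3)  # where x's existing dims land in the target shape

out = x
for idx in range(len(shape)):
    if idx not in broadcast_dimensions:  # new axes must be size 1
        out = out.unsqueeze(idx)

assert tuple(out.shape) == shape  # (1, 3, 1, 4)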

py/torch_tensorrt/dynamo/conversion/op_evaluators.py renamed to py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

+1-1
@@ -19,7 +19,7 @@ def getitem_validator(getitem_node: Node) -> bool:
 
 
 # TODO: Subsequent evaluators should be registered here with their own validators
-@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)  # type: ignore[misc]
 def generic_evaluator(
     ctx: ConversionContext,
     target: Target,

py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py

+44

@@ -0,0 +1,44 @@
+import logging
+from typing import Dict, Sequence, Tuple, Union
+
+import torch
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.fx.types import TRTTensor
+
+from .converter_registry import dynamo_tensorrt_converter
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+# TODO: expand the scope of this converter with aten.expand implementation
+def broadcast_checker(broadcast_node: torch.fx.Node) -> bool:
+    # The current implementation of broadcast_in_dim can only handle unsqueeze
+    return all(
+        broadcast_node.args[1][i] == 1
+        for i in range(len(broadcast_node.args[1]))
+        if i not in broadcast_node.args[2]
+    )
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.prims.broadcast_in_dim.default, capability_validator=broadcast_checker
+)  # type: ignore[misc]
+def aten_ops_broadcast_in_dim(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unsqueeze.broadcast_in_dim(
+        ctx,
+        target,
+        SourceIR.PRIM,
+        name,
+        args[0],
+        args[1],
+        args[2],
+    )
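
For illustration only (outside the diff): a worked example of the validator's behavior with hypothetical node arguments, where args[1] is the target shape and args[2] the broadcast_dimensions. Broadcasts that only insert size-1 axes are accepted; anything that would grow an existing axis is rejected, so the node is not claimed by this converter.

# Hypothetical shapes mimicking broadcast_checker's acceptance rule.
def accepts(shape, broadcast_dims):
    return all(shape[i] == 1 for i in range(len(shape)) if i not in broadcast_dims)

print(accepts((1, 3, 1, 4), (1, 3)))  # True  -- pure unsqueeze broadcast
print(accepts((2, 3, 1, 4), (1, 3)))  # False -- would tile axis 0, not yet supported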

tests/py/dynamo/conversion/test_div_aten.py

+18
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
@@ -82,6 +83,23 @@ def forward(self, lhs_val):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ("2d", (2, 1)),
+            ("3d", (2, 1, 2)),
+        ]
+    )
+    def test_prims_div_tensor(self, _, shape):
+        class div(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.prims.div.default(lhs_val, rhs_val)
+
+        inputs = [torch.randn(shape), torch.randn(shape)]
+        self.run_test(
+            div(),
+            inputs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_sum_aten.py

+21
@@ -108,5 +108,26 @@ def forward(self, x):
         )
 
 
+class TestPrimsSumConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1]),
+            ((2, 1, 4, 5), [1, 2]),
+            ((2, 3, 4, 5), [0, 1, 2, 3]),
+            ((6, 7, 5, 4, 5), [1, 3, 4]),
+        ]
+    )
+    def test_sum_dim_sequence(self, input_shape, dim):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.sum.default(x, dim)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sum(),
+            inputs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
