Skip to content

Commit 349b08d

Browse files
committed
fix/feat: Add and repair multiple converters
- Focus on SD-performance-accelerating converters
- Add test cases for converters to avoid regressions
- Add prims sum converter
1 parent a7f9055 commit 349b08d

File tree

11 files changed

+254
-33
lines changed

11 files changed

+254
-33
lines changed

py/torch_tensorrt/dynamo/conversion/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from ._TRTInterpreter import * # noqa: F403
33
from .aten_ops_converters import * # noqa: F403
44
from .conversion import * # noqa: F403
5-
from .op_evaluators import * # noqa: F403
5+
from .ops_evaluators import * # noqa: F403
6+
from .prims_ops_converters import * # noqa: F403
67
from .truncate_long_and_double import repair_long_or_double_inputs

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,24 @@ def args_bounds_check(
2727
return args[i] if len(args) > i else replacement
2828

2929

30+
def get_ir(target: Target) -> SourceIR:
    """Infer the source IR of a converter target from its ``__module__``.

    Converters registered for operators from more than one namespace pass
    the result as the source-IR argument of the ``impl`` layer builders
    instead of hard-coding a single IR.

    Args:
        target: The FX node target being converted.

    Returns:
        ``SourceIR.ATEN`` when the target's module starts with
        ``torch.ops.aten`` or ``torch._ops.aten``; ``SourceIR.PRIM`` when it
        starts with ``torch.ops.prims`` or ``torch._ops.prims``;
        ``SourceIR.NN`` when it starts with ``torch.nn``; and
        ``SourceIR.UNKNOWN`` in every other case, including targets that have
        no string ``__module__``.
    """
    # Targets without a ``__module__`` attribute fall back to the string
    # "None", which matches none of the prefixes below.
    target_module = getattr(target, "__module__", "None")

    # ``__module__`` can exist but hold a non-string value (e.g. ``None``);
    # such targets cannot be classified by prefix.
    if not isinstance(target_module, str):
        return SourceIR.UNKNOWN

    # The ATen check must test the *aten* prefixes. Testing the prims
    # prefixes here would label every prims op as ATEN and leave the PRIM
    # branch below unreachable.
    if target_module.startswith(("torch.ops.aten", "torch._ops.aten")):
        return SourceIR.ATEN
    if target_module.startswith(("torch.ops.prims", "torch._ops.prims")):
        return SourceIR.PRIM
    if target_module.startswith("torch.nn"):
        return SourceIR.NN

    return SourceIR.UNKNOWN
46+
47+
3048
@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc]
3149
def aten_ops_batch_norm(
3250
ctx: ConversionContext,
@@ -651,23 +669,37 @@ def aten_ops_amax(
651669

652670
@dynamo_tensorrt_converter(torch.ops.aten.sum.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.prims.sum.default)  # type: ignore[misc]
def aten_ops_sum(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.sum`` / ``prims.sum`` into a TensorRT reduction.

    Args:
        ctx: Conversion context forwarded to the ``impl`` helpers.
        target: The operator being converted; also used to pick the source IR.
        args: Positional operator arguments. ``args[0]`` is the input,
            ``args[1]`` (default ``None``) is the reduction ``dim`` and
            ``args[2]`` (default ``False``) is forwarded as the last argument
            of ``impl.reduce.sum`` (presumably ``keepdim`` - confirm against
            ``impl.reduce.sum``).
        kwargs: Keyword operator arguments. A non-``None`` ``output_dtype``
            requests a cast of the reduced result.
        name: Base name for the created layers.

    Returns:
        The reduced tensor, cast to ``output_dtype`` when one was requested.
    """
    # Resolve the source IR once so that the reduce layer and the optional
    # cast layer are labelled consistently for both aten and prims targets.
    source_ir = get_ir(target)

    sum_ = impl.reduce.sum(
        ctx,
        target,
        source_ir,
        name,
        args[0],
        args_bounds_check(args, 1, replacement=None),
        args_bounds_check(args, 2, replacement=False),
    )

    output_dtype = kwargs.get("output_dtype")
    if output_dtype is None:
        return sum_

    return impl.cast.to_copy(
        ctx,
        target,
        source_ir,
        name,
        sum_,
        output_dtype,
        force_layer=False,
    )
702+
671703

672704
@dynamo_tensorrt_converter(torch.ops.aten.exp.default) # type: ignore[misc]
673705
def aten_ops_exp(
@@ -1166,6 +1198,7 @@ def aten_ops_sub(
11661198
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) # type: ignore[misc]
11671199
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar) # type: ignore[misc]
11681200
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode) # type: ignore[misc]
1201+
@dynamo_tensorrt_converter(torch.ops.prims.div.default) # type: ignore[misc]
11691202
def aten_ops_div(
11701203
ctx: ConversionContext,
11711204
target: Target,
@@ -1179,7 +1212,7 @@ def aten_ops_div(
11791212
return impl.elementwise.div(
11801213
ctx,
11811214
target,
1182-
SourceIR.ATEN,
1215+
get_ir(target),
11831216
name,
11841217
args[0],
11851218
args[1],
@@ -1188,7 +1221,7 @@ def aten_ops_div(
11881221
return impl.elementwise.floor_divide(
11891222
ctx,
11901223
target,
1191-
SourceIR.ATEN,
1224+
get_ir(target),
11921225
name,
11931226
args[0],
11941227
args[1],
@@ -1197,7 +1230,7 @@ def aten_ops_div(
11971230
return impl.elementwise.trunc_div(
11981231
ctx,
11991232
target,
1200-
SourceIR.ATEN,
1233+
get_ir(target),
12011234
name,
12021235
args[0],
12031236
args[1],

py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

+23-24
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional
22

3+
import numpy as np
34
import tensorrt as trt
45
import torch
56
from torch.fx.node import Target
@@ -23,16 +24,6 @@ def where(
2324
other: TRTTensor,
2425
condition: TRTTensor,
2526
) -> TRTTensor:
26-
input_dim = len(tuple(input.shape))
27-
other_dim = len(tuple(other.shape))
28-
condition_dim = len(tuple(condition.shape))
29-
30-
if type(input) != TRTTensor:
31-
assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!"
32-
33-
if type(other) != TRTTensor:
34-
assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!"
35-
3627
if not (broadcastable(input, other)):
3728
assert "The two torch tensors should be broadcastable"
3829

@@ -49,33 +40,37 @@ def where(
4940
x_shape = list(input.shape)
5041
y_shape = list(other.shape)
5142
condition_shape = list(condition.shape)
43+
5244
output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
5345

5446
# expand shape
55-
if type(condition) != TRTTensor:
56-
assert condition.dtype == torch.bool, "condition dtype is not bool"
47+
if not isinstance(condition, TRTTensor):
48+
assert condition.dtype in (torch.bool, np.bool_), "condition dtype is not bool"
5749
if condition_shape != output_shape:
58-
condition.expand(output_shape)
59-
condition = condition.to(torch.int32)
60-
condition_const = get_trt_tensor(ctx, condition, f"{name}_condition")
61-
condition_layer = ctx.net.add_identity(condition_const)
62-
condition_layer.set_output_type(0, trt.bool)
63-
set_layer_name(condition_layer, target, f"{name}_condition")
64-
condition_val = condition_layer.get_output(0)
50+
condition = (
51+
condition.expand(output_shape)
52+
if isinstance(condition, torch.Tensor)
53+
else np.broadcast_to(condition, output_shape)
54+
)
55+
condition_val = get_trt_tensor(ctx, condition, f"{name}_condition")
6556
else:
6657
assert condition.dtype == trt.bool, "mask dtype is not bool!"
67-
if len(condition_shape) != condition_dim:
58+
if condition_shape != output_shape:
6859
condition_val = expand(
6960
ctx, target, source_ir, f"{name}_expand", condition, output_shape
7061
)
7162
else:
7263
condition_val = condition
7364

74-
if type(input) != TRTTensor:
65+
if not isinstance(input, TRTTensor):
7566
if x_shape != output_shape:
7667
# special case where 1 element in input
7768
if len(input.shape) == 0:
78-
input = input.unsqueeze(0)
69+
input = (
70+
input.unsqueeze(0)
71+
if isinstance(input, torch.Tensor)
72+
else np.expand_dims(input, axis=0)
73+
)
7974
input = input.expand(output_shape)
8075
x_val = get_trt_tensor(ctx, input, f"{name}_x")
8176
else:
@@ -85,11 +80,15 @@ def where(
8580
ctx, target, source_ir, f"{name}_x_expand", input, output_shape
8681
)
8782

88-
if type(other) != TRTTensor:
83+
if not isinstance(other, TRTTensor):
8984
if y_shape != output_shape:
9085
# special case where 1 element in other
9186
if len(other.shape) == 0:
92-
other = other.unsqueeze(0)
87+
other = (
88+
other.unsqueeze(0)
89+
if isinstance(other, torch.Tensor)
90+
else np.expand_dims(other, axis=0)
91+
)
9392
other = other.expand(output_shape)
9493
y_val = get_trt_tensor(ctx, other, f"{name}_y")
9594
else:

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def sum(
5151
):
5252
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
5353

54-
if dim is None:
54+
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
5555
dim = tuple(range(len(input_val.shape)))
5656
layer = ctx.net.add_reduce(
5757
input_val,

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, cast
1+
from typing import List, Optional, Sequence, cast
22

33
from torch.fx.node import Target
44
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -49,3 +49,42 @@ def unsqueeze(
4949
)
5050
set_layer_name(layer, target, name, source_ir)
5151
return layer.get_output(0)
52+
53+
54+
def broadcast_in_dim(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_t: TRTTensor,
    shape: Sequence[int],
    broadcast_dimensions: Sequence[int],
) -> TRTTensor:
    """Broadcast ``input_t`` to ``shape`` using only unsqueeze operations.

    Every position of ``shape`` that is not listed in
    ``broadcast_dimensions`` must have extent 1; one ``unsqueeze`` call is
    issued for each such position, in increasing order.

    Args:
        ctx: Conversion context forwarded to ``unsqueeze``.
        target: Operator being converted, forwarded to ``unsqueeze``.
        source_ir: Source IR label forwarded to ``unsqueeze``.
        name: Base layer name; each unsqueeze appends
            ``_unsqueeze_for_broadcast_<position>``.
        input_t: Tensor to broadcast.
        shape: Requested output shape.
        broadcast_dimensions: Positions of ``shape`` that are not unsqueezed.

    Returns:
        The tensor produced by the last unsqueeze, or ``input_t`` itself when
        no position needed one.

    Raises:
        AssertionError: If a non-listed position has an extent other than 1,
            or if the resulting shape differs from ``shape``.
    """
    # Mask out the listed positions with None; whatever stays numeric is the
    # extent of a dimension that still has to be inserted.
    pending_dims: List[Optional[int]] = list(shape)
    for listed_dim in broadcast_dimensions:
        pending_dims[listed_dim] = None

    # TODO: Expand support to arbitrary broadcasts
    assert all(
        extent in (1, None) for extent in pending_dims
    ), "broadcast_in_dim currently only supports unsqueeze broadcasting"

    # Insert the missing size-1 dimensions one position at a time
    result = input_t
    for position, extent in enumerate(pending_dims):
        if extent is None:
            continue
        result = unsqueeze(
            ctx,
            target,
            source_ir,
            f"{name}_unsqueeze_for_broadcast_{position}",
            result,
            position,
        )

    assert tuple(result.shape) == tuple(shape), "broadcast_in_dim shapes don't match"

    return result

py/torch_tensorrt/dynamo/conversion/op_evaluators.py renamed to py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def getitem_validator(getitem_node: Node) -> bool:
1919

2020

2121
# TODO: Subsequent evaluators should be registered here with their own validators
22-
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
22+
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator) # type: ignore[misc]
2323
def generic_evaluator(
2424
ctx: ConversionContext,
2525
target: Target,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import logging
2+
from typing import Dict, Sequence, Tuple, Union
3+
4+
import torch
5+
from torch.fx.node import Argument, Target
6+
from torch_tensorrt.dynamo._SourceIR import SourceIR
7+
from torch_tensorrt.dynamo.conversion import impl
8+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
9+
from torch_tensorrt.fx.types import TRTTensor
10+
11+
from .converter_registry import dynamo_tensorrt_converter
12+
13+
_LOGGER: logging.Logger = logging.getLogger(__name__)
14+
15+
16+
# TODO: expand the scope of this converter with aten.expand implementation
17+
def broadcast_checker(broadcast_node: torch.fx.Node) -> bool:
    """Report whether a ``broadcast_in_dim`` node is an unsqueeze-only broadcast.

    Args:
        broadcast_node: Node whose ``args[1]`` is the target shape and whose
            ``args[2]`` holds the broadcast dimensions.

    Returns:
        ``True`` when every position of the target shape that is not among
        the broadcast dimensions has extent 1, ``False`` otherwise.
    """
    # The current implementation of broadcast_in_dim can only handle unsqueeze
    target_shape = broadcast_node.args[1]
    broadcast_dims = broadcast_node.args[2]

    for position, extent in enumerate(target_shape):
        if position in broadcast_dims:
            continue
        if extent != 1:
            return False
    return True
24+
25+
26+
@dynamo_tensorrt_converter(
    torch.ops.prims.broadcast_in_dim.default, capability_validator=broadcast_checker
)  # type: ignore[misc]
def aten_ops_broadcast_in_dim(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``prims.broadcast_in_dim`` via ``impl.unsqueeze.broadcast_in_dim``.

    Registered with ``broadcast_checker`` as its capability validator.

    Args:
        ctx: Conversion context forwarded to the implementation.
        target: Operator being converted.
        args: ``args[0]`` is the input tensor, ``args[1]`` the target shape
            and ``args[2]`` the broadcast dimensions.
        kwargs: Unused.
        name: Base name for the created layers.

    Returns:
        Whatever ``impl.unsqueeze.broadcast_in_dim`` returns for the inputs.
    """
    input_t = args[0]
    shape = args[1]
    broadcast_dimensions = args[2]

    # The layers are always labelled as PRIM
    return impl.unsqueeze.broadcast_in_dim(
        ctx,
        target,
        SourceIR.PRIM,
        name,
        input_t,
        shape,
        broadcast_dimensions,
    )

tests/py/dynamo/conversion/test_div_aten.py

+18
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,24 @@ def forward(self, lhs_val):
8282
expected_ops={torch.ops.aten.div.Tensor_mode},
8383
)
8484

85+
@parameterized.expand(
    [
        ("2d", (2, 1)),
        ("3d", (2, 1, 2)),
    ]
)
def test_prims_div_tensor(self, _, shape):
    """Run ``prims.div`` on two random tensors of ``shape`` through ``run_test``.

    The first parameter of each case is only a label and is ignored.
    """

    class div(nn.Module):
        def forward(self, lhs_val, rhs_val):
            return torch.ops.prims.div.default(lhs_val, rhs_val)

    # Left operand is drawn first, then the right one
    lhs = torch.randn(shape)
    rhs = torch.randn(shape)

    self.run_test(
        div(),
        [lhs, rhs],
        expected_ops={torch.ops.prims.div.default},
    )
102+
85103

86104
if __name__ == "__main__":
87105
run_tests()

tests/py/dynamo/conversion/test_sum_aten.py

+22
Original file line numberDiff line numberDiff line change
@@ -113,5 +113,27 @@ def forward(self, x):
113113
)
114114

115115

116+
class TestPrimsSumConverter(DispatchTestCase):
    """Checks for ``torch.ops.prims.sum.default`` reducing over several dims."""

    @parameterized.expand(
        [
            ((3, 2, 4), [1]),
            ((2, 1, 4, 5), [1, 2]),
            ((2, 3, 4, 5), [0, 1, 2, 3]),
            ((6, 7, 5, 4, 5), [1, 3, 4]),
        ]
    )
    def test_sum_dim_sequence(self, input_shape, dim):
        """Sum a random tensor of ``input_shape`` over the dimensions in ``dim``."""

        class Sum(nn.Module):
            def forward(self, x):
                return torch.ops.prims.sum.default(x, dim)

        sample = torch.randn(*input_shape)

        self.run_test(
            Sum(),
            [sample],
            expected_ops={torch.ops.prims.sum.default},
        )
136+
137+
116138
if __name__ == "__main__":
117139
run_tests()

0 commit comments

Comments
 (0)