Commit 6cbb24e

perf: Add efficient attention lowering pass

1 parent d2e4f6d commit 6cbb24e

File tree

8 files changed (+257, -10 lines changed)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+15
@@ -1500,3 +1500,18 @@ def aten_ops_max_pool(
         dilation=args_bounds_check(args, 4, replacement=1),
         ceil_mode=args_bounds_check(args, 5, replacement=False),
     )
+
+
+@dynamo_tensorrt_converter(
+    torch.nn.functional.scaled_dot_product_attention,
+)  # type: ignore[misc]
+def tensorrt_scaled_dot_product_attention(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.attention.scaled_dot_product_attention(
+        network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+    )

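Note on the registration key: the converter above is registered against the Python function torch.nn.functional.scaled_dot_product_attention rather than an ATen overload, because the lowering pass added later in this commit rewrites the fused ATen op into exactly that functional call, and FX records the call with the function object itself as the node target. A minimal standalone sketch (not part of the commit) illustrating that behavior:

import torch
import torch.nn.functional as F


class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


# symbolic_trace keeps the functional call as a single call_function node whose
# target is the same function object the converter decorator is keyed on
gm = torch.fx.symbolic_trace(SDPA())
print(any(node.target is F.scaled_dot_product_attention for node in gm.graph.nodes))  # True
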
py/torch_tensorrt/dynamo/conversion/impl/__init__.py

+1
@@ -2,6 +2,7 @@
 
 from . import (
     activation,
+    attention,
     cast,
     condition,
     conv,
py/torch_tensorrt/dynamo/conversion/impl/attention.py

+49

@@ -0,0 +1,49 @@
+import math
+from typing import Optional, Union
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def scaled_dot_product_attention(
+    network: TRTNetwork,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    query: TRTTensor,
+    key: TRTTensor,
+    value: TRTTensor,
+) -> TRTTensor:
+    mm = impl.matmul.matrix_multiply(
+        network,
+        target,
+        source_ir,
+        name + "_mm",
+        query,
+        key,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+    div = impl.elementwise.div(
+        network,
+        target,
+        source_ir,
+        name + "_scale",
+        mm,
+        math.sqrt(query.shape[-1]),
+    )
+    softmax = impl.normalization.softmax(
+        network, target, source_ir, name + "_softmax", div, -1
+    )
+    out = impl.matmul.matrix_multiply(
+        network,
+        target,
+        source_ir,
+        name + "_out",
+        softmax,
+        value,
+    )
+
+    return out

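The converter above builds attention from four TRT layer groups: Q·Kᵀ matmul, division by sqrt(d_k), softmax over the last dimension, and a final matmul with V. For reference, a minimal PyTorch sketch (not part of the commit) of the same decomposition, checked against the fused functional it stands in for; the tensor shapes are borrowed from the unit tests below:

import math

import torch


def sdpa_reference(query, key, value):
    # (Q @ K^T) / sqrt(d_k) -> softmax over the last dim -> matmul with V,
    # mirroring the impl.matmul / impl.elementwise / impl.normalization calls above
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.shape[-1])
    return torch.matmul(torch.softmax(scores, dim=-1), value)


q, k, v = torch.rand(8, 4, 5, 4), torch.rand(8, 4, 2, 4), torch.rand(8, 4, 2, 4)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(torch.max(torch.abs(sdpa_reference(q, k, v) - ref)))  # ~0, up to float rounding
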
py/torch_tensorrt/dynamo/lowering/_decompositions.py

-5
@@ -83,11 +83,6 @@ def inplace_op(*args, **kwargs):  # type: ignore
 replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
 
 
-@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
-def std_replacement(*args, **kwargs) -> torch.Tensor:  # type: ignore
-    return torch.sqrt(torch.var(*args, **kwargs))
-
-
 @register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
 def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:  # type: ignore
     return torch.reciprocal(torch.sqrt(*args, **kwargs))

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

+2
@@ -5,6 +5,7 @@
 
 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
+from .lower_efficient_attention import lower_efficient_attention
 from .pass_manager import DynamoPassManager
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
@@ -14,6 +15,7 @@
         remove_input_alias_fixing_clones,
         constant_fold,
         repair_input_as_output,
+        lower_efficient_attention,
         fuse_prims_broadcast,
     ]
 )

py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py

+1 -1

@@ -77,6 +77,6 @@ def fuse_prims_broadcast(
 
     if modified_graph:
         gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Fused prims-broadcast paradigm:\n{gm.graph}")
+        logger.debug(f"Graph after fusing prims-broadcast paradigm:\n{gm.graph}")
 
     return gm
py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py

+74

@@ -0,0 +1,74 @@
+import logging
+import operator
+from typing import Callable, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+    get_tensor_placeholders,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def lower_efficient_attention(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace a specific version of scaled_dot_product_attention with an equivalent
+    implementation which can be easily converted to TRT
+    """
+    orig, replacement = efficient_attention_replacement()
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(
+            f"Graph after lowering _scaled_dot_product_efficient_attention:\n{gm.graph}"
+        )
+
+    return gm
+
+
+def efficient_attention_replacement() -> (
+    Tuple[
+        torch.fx.GraphModule,
+        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+    ]
+):
+    """Constructs the original and replacement functions for efficient attention"""
+
+    # Empty boilerplate function taking in three Tensors and returning one
+    def boilerplate(
+        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> torch.Tensor:
+        ...
+
+    # Trace boilerplate function and extract placeholder and output nodes
+    orig = torch.fx.symbolic_trace(boilerplate)
+    q, k, v = get_tensor_placeholders(orig)
+    output = [node for node in orig.graph.nodes if node.op == "output"][0]
+
+    # Graph types to replace are those which use the _scaled_dot_product_efficient_attention
+    # function and extract only the first element
+    with orig.graph.inserting_before(output):
+        att = orig.graph.call_function(
+            torch.ops.aten._scaled_dot_product_efficient_attention.default,
+            args=(q, k, v, None, False),
+        )
+        out = orig.graph.call_function(
+            operator.getitem,
+            args=(att, 0),
+        )
+
+    # Assign the output of the graph to be the single getitem output
+    output.args = (out,)
+
+    orig.graph.lint()
+    orig.recompile()
+
+    # Replacement graph consists of the functional version of scaled_dot_product_attention
+    def replacement(
+        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> torch.Tensor:
+        return torch.nn.functional.scaled_dot_product_attention(query, key, value)
+
+    return orig, replacement

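As a rough illustration of the pass in action (standalone sketch, not part of the commit; the import path below assumes the py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py location inferred above), tracing a module that calls the fused ATen op and running the pass replaces the (_scaled_dot_product_efficient_attention, getitem) pair with a single functional call that the new converter can handle:

import torch

from torch_tensorrt.dynamo.lowering.passes.lower_efficient_attention import (
    lower_efficient_attention,
)


class EfficientAttention(torch.nn.Module):
    def forward(self, q, k, v):
        attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
            q, k, v, None, False
        )
        return attn[0]


gm = torch.fx.symbolic_trace(EfficientAttention())
print([n.target for n in gm.graph.nodes if n.op == "call_function"])
# roughly: [aten._scaled_dot_product_efficient_attention.default, operator.getitem]

# sample_inputs is unused by this pass, so an empty list is enough here
gm = lower_efficient_attention(gm, sample_inputs=[])
print([n.target for n in gm.graph.nodes if n.op == "call_function"])
# roughly: [torch.nn.functional.scaled_dot_product_attention]
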
tests/py/dynamo/lowering/test_aten_lowering_passes.py

+115 -4

@@ -92,8 +92,8 @@ def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
 
 
 class TestPrimBroadcastFusion(TestCase):
-    def test_input_as_output(self):
-        class InputAsOutput(torch.nn.Module):
+    def test_broadcast_fusion(self):
+        class BroadcastFusion(torch.nn.Module):
             def forward(self, x):
                 return torch.var_mean(x, keepdim=True)[1]
 
@@ -104,7 +104,7 @@ def forward(self, x):
             ).cuda(),
         ]
 
-        fx_graph = torch.fx.symbolic_trace(InputAsOutput())
+        fx_graph = torch.fx.symbolic_trace(BroadcastFusion())
         expected_ops = {torch.ops.aten.sum.dim_IntList}
         unexpected_ops = {torch.ops.aten.var.default, torch.ops.prims.var.default}
 
@@ -151,7 +151,118 @@ def forward(self, x):
             max_diff,
             0,
             DECIMALS_OF_AGREEMENT,
-            msg=f"InputAsOutput TRT outputs don't match with the original model.",
+            msg=f"BroadcastFusion TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
+class TestLowerEfficientAttention(TestCase):
+    def test_lower_efficient_attention(self):
+        class EfficientAttention(torch.nn.Module):
+            def forward(self, q, k, v):
+                attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+                    q, k, v, None, False
+                )
+                return attn[0]
+
+        inputs = [
+            torch.rand(8, 4, 5, 4).cuda(),
+            torch.rand(8, 4, 2, 4).cuda(),
+            torch.rand(8, 4, 2, 4).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(EfficientAttention())
+        expected_ops = {torch.nn.functional.scaled_dot_product_attention}
+        unexpected_ops = {
+            torch.ops.aten._scaled_dot_product_efficient_attention.default
+        }
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"EfficientAttention TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_efficient_attention_converter(self):
+        class EfficientAttention(torch.nn.Module):
+            def forward(self, q, k, v):
+                attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+                    q, k, v, None, False
+                )
+                return attn[0]
+
+        inputs = [
+            torch.rand(1, 3, 6, 4).cuda(),
+            torch.rand(1, 3, 2, 4).cuda(),
+            torch.rand(1, 3, 2, 4).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(EfficientAttention())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"EfficientAttention TRT outputs don't match with the original model.",
         )
         torch._dynamo.reset()
 
0 commit comments