Skip to content

Commit f65340d

Browse files
authored
fix: Add decomposition for aten.addmm (#1953)
1 parent 3fc3c6d commit f65340d

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py

+9
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,14 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor:
5656
return x
5757

5858

59+
@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)
60+
def addmm_replacement(
61+
input_: torch.Tensor, mat1: torch.Tensor, mat2: torch.Tensor, *, beta=1, alpha=1
62+
) -> torch.Tensor:
63+
return torch.add(
64+
torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha)
65+
)
66+
67+
5968
def get_decompositions():
6069
return DECOMPOSITIONS

py/torch_tensorrt/dynamo/backend/test/test_lowering.py renamed to py/torch_tensorrt/dynamo/backend/test/test_decompositions.py

+70
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from utils import lower_graph_testing
33
from torch.testing._internal.common_utils import run_tests, TestCase
44
import torch
5+
from torch_tensorrt.dynamo import compile
6+
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
57

68

79
class TestLowering(TestCase):
@@ -109,6 +111,74 @@ def forward(self, x):
109111
f"The following expected ops were not encountered: {expected_ops_unseen}",
110112
)
111113

114+
def test_lowering_addmm(self):
115+
class AddMM(torch.nn.Module):
116+
def forward(self, x, y, z):
117+
return torch.addmm(x, y, z, beta=16, alpha=5)
118+
119+
# Operations expected to be included in the traced graph after decompositions
120+
expected_ops = {
121+
torch.ops.aten.add.Tensor,
122+
torch.ops.aten.mul.Tensor,
123+
torch.ops.aten.mm.default,
124+
}
125+
unexpected_ops = {torch.ops.aten.addmm.default}
126+
127+
inputs = [
128+
torch.rand(
129+
1,
130+
1,
131+
).cuda(),
132+
torch.rand(
133+
7,
134+
8,
135+
).cuda(),
136+
torch.rand(
137+
8,
138+
9,
139+
).cuda(),
140+
]
141+
142+
fx_graph = torch.fx.symbolic_trace(AddMM())
143+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
144+
fx_graph,
145+
inputs,
146+
expected_ops=expected_ops,
147+
unexpected_ops=unexpected_ops,
148+
min_block_size=1,
149+
)
150+
151+
self.assertEquals(
152+
len(unexpected_ops_seen),
153+
0,
154+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
155+
)
156+
157+
self.assertEquals(
158+
len(expected_ops_unseen),
159+
0,
160+
f"The following expected ops were not encountered: {expected_ops_unseen}",
161+
)
162+
163+
torch._dynamo.reset()
164+
165+
# Validate that the results between Torch and Torch-TRT are similar
166+
optimized_model = compile(
167+
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
168+
)
169+
optimized_model_results = optimized_model(*inputs).detach().cpu()
170+
torch_model_results = fx_graph(*inputs).detach().cpu()
171+
172+
max_diff = float(
173+
torch.max(torch.abs(optimized_model_results - torch_model_results))
174+
)
175+
self.assertAlmostEqual(
176+
max_diff,
177+
0,
178+
DECIMALS_OF_AGREEMENT,
179+
f"AddMM TRT outputs don't match with the original model.",
180+
)
181+
112182

113183
if __name__ == "__main__":
114184
run_tests()

py/torch_tensorrt/dynamo/common_utils/test_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22

33
COSINE_THRESHOLD = 0.99
4+
DECIMALS_OF_AGREEMENT = 5
45

56

67
def cosine_similarity(gt_tensor, pred_tensor):

0 commit comments

Comments
 (0)