|
2 | 2 | from utils import lower_graph_testing
|
3 | 3 | from torch.testing._internal.common_utils import run_tests, TestCase
|
4 | 4 | import torch
|
| 5 | +from torch_tensorrt.dynamo import compile |
| 6 | +from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT |
5 | 7 |
|
6 | 8 |
|
7 | 9 | class TestLowering(TestCase):
|
@@ -109,6 +111,74 @@ def forward(self, x):
|
109 | 111 | f"The following expected ops were not encountered: {expected_ops_unseen}",
|
110 | 112 | )
|
111 | 113 |
|
def test_lowering_addmm(self):
    """Check that ``aten.addmm`` is decomposed during lowering and stays accurate.

    The test has two phases:

    1. Lower a traced ``torch.addmm`` module and verify that the resulting
       graph contains ``add``, ``mul`` and ``mm`` but no ``addmm`` op.
    2. Compile the same traced module with Torch-TRT and verify that its
       output agrees with the eager result to ``DECIMALS_OF_AGREEMENT``
       decimal places.

    Requires a CUDA device, since all inputs are moved to the GPU.
    """

    class AddMM(torch.nn.Module):
        def forward(self, x, y, z):
            # Non-default beta/alpha so the scaling multiplications show up
            # as explicit ops in the decomposed graph.
            return torch.addmm(x, y, z, beta=16, alpha=5)

    # Operations expected to be included in the traced graph after decompositions
    expected_ops = {
        torch.ops.aten.add.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mm.default,
    }
    unexpected_ops = {torch.ops.aten.addmm.default}

    # The (1, 1) bias broadcasts against the (7, 8) @ (8, 9) -> (7, 9) product.
    inputs = [
        torch.rand(1, 1).cuda(),
        torch.rand(7, 8).cuda(),
        torch.rand(8, 9).cuda(),
    ]

    fx_graph = torch.fx.symbolic_trace(AddMM())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    # assertEqual, not the deprecated assertEquals alias, which was removed
    # from unittest in Python 3.12.
    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    # Reset Dynamo state between the lowering check and the compile below.
    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = compile(
        fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    # Largest element-wise absolute deviation between the two outputs.
    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        places=DECIMALS_OF_AGREEMENT,
        msg="AddMM TRT outputs don't match with the original model.",
    )
| 181 | + |
112 | 182 |
|
113 | 183 | if __name__ == "__main__":
|
114 | 184 | run_tests()
|
0 commit comments