Skip to content

Commit fc12fbd

Browse files
committed
feat: support bmm converter in dynamo
Signed-off-by: Bo Wang <[email protected]>
1 parent 670d2be commit fc12fbd

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+1
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def aten_ops_hard_sigmoid(
245245
@dynamo_tensorrt_converter(torch.ops.aten.matmul) # type: ignore[misc]
246246
@dynamo_tensorrt_converter(torch.ops.aten.mm.default) # type: ignore[misc]
247247
@dynamo_tensorrt_converter(torch.ops.aten.mv.default) # type: ignore[misc]
248+
@dynamo_tensorrt_converter(torch.ops.aten.bmm.default) # type: ignore[misc]
248249
def aten_ops_matmul(
249250
network: TRTNetwork,
250251
target: Target,
+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestBmmConverter(DispatchTestCase):
    """Exercise the dynamo lowering path for batched matrix multiply.

    Each case feeds a pair of 3-D tensors through ``torch.bmm`` and checks
    that the graph contains ``aten.bmm.default`` — i.e. that the op reaches
    the dedicated bmm converter rather than being decomposed away.
    """

    @parameterized.expand(
        [
            # (test-name suffix, lhs shape, rhs shape); inner dims must agree.
            ("10_3_5", (10, 3, 4), (10, 4, 5)),
            ("1_10_1", (1, 10, 1), (1, 1, 1)),
            ("1_1_1", (1, 1, 1), (1, 1, 1)),
        ]
    )
    def test_bmm(self, _, input_shape, mat2_shape):
        # Minimal module wrapping a single batch matmul so the traced graph
        # contains exactly the op under test.
        class BMM(nn.Module):
            def forward(self, lhs, rhs):
                return torch.bmm(lhs, rhs)

        sample_inputs = [torch.randn(*input_shape), torch.randn(*mat2_shape)]

        # disable_passes keeps lowering passes from rewriting bmm before the
        # converter sees it.
        self.run_test(
            BMM(),
            sample_inputs,
            disable_passes=True,
            expected_ops={torch.ops.aten.bmm.default},
        )
33+
34+
35+
# Allow the file to be executed directly as a test script.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)