Skip to content

Commit 9f579d8

Browse files
committed
feat: support bmm converter in dynamo
Signed-off-by: Bo Wang <[email protected]>
1 parent 670d2be commit 9f579d8

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+1
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def aten_ops_hard_sigmoid(
245245
@dynamo_tensorrt_converter(torch.ops.aten.matmul) # type: ignore[misc]
246246
@dynamo_tensorrt_converter(torch.ops.aten.mm.default) # type: ignore[misc]
247247
@dynamo_tensorrt_converter(torch.ops.aten.mv.default) # type: ignore[misc]
248+
@dynamo_tensorrt_converter(torch.ops.aten.bmm.default) # type: ignore[misc]
248249
def aten_ops_matmul(
249250
network: TRTNetwork,
250251
target: Target,
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestBmmConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("10_3_5", (10, 3, 4), (10, 4, 5)),
13+
]
14+
)
15+
def test_bmm(self, _, input_shape, mat2_shape):
16+
class BMM(nn.Module):
17+
def __init__(self):
18+
super().__init__()
19+
20+
def forward(self, input, mat2):
21+
return torch.bmm(input, mat2)
22+
23+
inputs = [torch.randn(*input_shape), torch.randn(*mat2_shape)]
24+
25+
self.run_test(
26+
BMM(),
27+
inputs,
28+
disable_passes=True,
29+
expected_ops={torch.ops.aten.bmm.default},
30+
)
31+
32+
33+
if __name__ == "__main__":
34+
run_tests()

0 commit comments

Comments
 (0)