diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index b0f718256f..f5d1620573 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -245,6 +245,7 @@ def aten_ops_hard_sigmoid( @dynamo_tensorrt_converter(torch.ops.aten.matmul) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.mm.default) # type: ignore[misc] @dynamo_tensorrt_converter(torch.ops.aten.mv.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.bmm.default) # type: ignore[misc] def aten_ops_matmul( network: TRTNetwork, target: Target, diff --git a/tests/py/dynamo/conversion/test_bmm.py b/tests/py/dynamo/conversion/test_bmm.py new file mode 100644 index 0000000000..391bd0bf89 --- /dev/null +++ b/tests/py/dynamo/conversion/test_bmm.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestBmmConverter(DispatchTestCase): + @parameterized.expand( + [ + ("10_3_5", (10, 3, 4), (10, 4, 5)), + ("1_10_1", (1, 10, 1), (1, 1, 1)), + ("1_1_1", (1, 1, 1), (1, 1, 1)), + ] + ) + def test_bmm(self, _, input_shape, mat2_shape): + class BMM(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, mat2): + return torch.bmm(input, mat2) + + inputs = [torch.randn(*input_shape), torch.randn(*mat2_shape)] + + self.run_test( + BMM(), + inputs, + disable_passes=True, + expected_ops={torch.ops.aten.bmm.default}, + ) + + +if __name__ == "__main__": + run_tests()