Commit acf6944

converter: support bmm converter for dynamo
1 parent e77e445 commit acf6944

3 files changed, 78 insertions(+), 0 deletions(-)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+11)

@@ -182,6 +182,17 @@ def aten_ops_matmul(
         network, target, SourceIR.ATEN, name, args[0], args[1]
     )

+@dynamo_tensorrt_converter(torch.ops.aten.bmm.default)
+def aten_ops_bmm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.matmul.bmm(
+        network, target, SourceIR.ATEN, name, args[0], args[1]
+    )

 @dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)
 def aten_ops_layernorm(
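
A quick sanity check on the op registered above (illustrative, not part of the commit): torch.ops.aten.bmm.default is the ATen overload that dynamo-captured graphs carry for torch.bmm, so calling it directly gives the same result as torch.bmm.

import torch

a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)

# torch.ops.aten.bmm.default is the overload the decorator above targets;
# it dispatches to the same kernel as torch.bmm.
print(torch.allclose(torch.bmm(a, b), torch.ops.aten.bmm.default(a, b)))  # True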

py/torch_tensorrt/dynamo/conversion/impl/matmul.py (+34)

@@ -48,3 +48,37 @@ def matrix_multiply(
     layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
     set_layer_name(layer, target, name)
     return layer.get_output(0)
+
+
+def bmm(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    other: TRTTensor,
+) -> TRTTensor:
+    # Promote non-ITensor operands (constants, numpy/torch values) to TRT tensors.
+    if not isinstance(input, trt.tensorrt.ITensor):
+        input = get_trt_tensor(network, input, f"{name}_input")
+    if not isinstance(other, trt.tensorrt.ITensor):
+        other = get_trt_tensor(
+            network,
+            other,
+            f"{name}_other",
+            dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
+        )
+
+    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
+
+    # torch.bmm requires two 3-D tensors with matching batch dimensions.
+    if len(input.shape) != 3:
+        raise RuntimeError(f"Expected a 3-dimensional input tensor, but got {len(input.shape)} dimensions")
+
+    if len(other.shape) != 3:
+        raise RuntimeError(f"Expected a 3-dimensional mat2 tensor, but got {len(other.shape)} dimensions")
+
+    if input.shape[0] != other.shape[0]:
+        raise RuntimeError("Expected input tensors to have the same batch size.")
+
+    layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
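
For reference (illustrative, not part of the commit), the shape contract the checks above mirror: torch.bmm takes two 3-D tensors with equal batch sizes, and (B, n, m) x (B, m, p) yields (B, n, p).

import torch

a = torch.randn(10, 3, 4)  # (batch, n, m)
b = torch.randn(10, 4, 5)  # (batch, m, p); batch sizes must match
print(torch.bmm(a, b).shape)  # torch.Size([10, 3, 5])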
New test file (+33)

@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.test_utils import DispatchTestCase
+from torch_tensorrt import Input
+
+
+class TestBmmConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("10_3_5", (10, 3, 4), (10, 4, 5)),
+        ]
+    )
+    def test_bmm(self, _, input_shape, mat2_shape):
+        class BMM(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, mat2):
+                return torch.bmm(input, mat2)
+
+        inputs = [torch.randn(*input_shape), torch.randn(*mat2_shape)]
+
+        self.run_test(
+            BMM(),
+            inputs,
+            expected_ops={},
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
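
Broader coverage could reuse the same harness (illustrative only, not part of the commit); any extra (name, input_shape, mat2_shape) tuples added to the @parameterized.expand list above each become a test case.

# Hypothetical extra entries for the @parameterized.expand list above:
extra_cases = [
    ("1_2_3", (1, 2, 4), (1, 4, 3)),       # single-batch case
    ("32_8_8", (32, 8, 16), (32, 16, 8)),  # larger batch
]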
