
Commit 8c92918

feat: support linear (fully connected layer) dynamo converter (#2253)
1 parent c377c48 commit 8c92918

File tree

3 files changed: +73 -0 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+19

@@ -1289,3 +1289,22 @@ def aten_ops_convolution(
         dilation=args[5],
         groups=args[8],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.linear.default)
+def aten_ops_linear(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.linear.linear(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args_bounds_check(args, 2, None),
+    )
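For context (not part of the diff): torch.ops.aten.linear.default, which this converter is registered against via @dynamo_tensorrt_converter, computes input @ weight.T + bias. A minimal sketch in plain PyTorch of the semantics the converter has to reproduce:

    import torch

    x = torch.randn(2, 4)   # input  (batch, in_features)
    w = torch.randn(3, 4)   # weight (out_features, in_features)
    b = torch.randn(3)      # bias   (out_features,)

    # aten.linear multiplies against the transposed weight and adds the bias,
    # which is what the linear() implementation below recreates with TensorRT layers.
    out = torch.ops.aten.linear.default(x, w, b)
    assert torch.allclose(out, x @ w.t() + b)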

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

+1

@@ -7,6 +7,7 @@
     conv,
     elementwise,
     embedding,
+    linear,
     matmul,
     normalization,
     permutation,
py/torch_tensorrt/dynamo/conversion/impl/linear.py (new file)

+53

@@ -0,0 +1,53 @@
+from typing import Optional, Union
+
+import numpy as np
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.fx.converters.converter_utils import SourceIR, get_trt_tensor
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def linear(
+    network: TRTNetwork,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    weight: Union[TRTTensor, torch.Tensor, np.ndarray],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+) -> TRTTensor:
+    # Process weight terms
+    if not isinstance(weight, (TRTTensor, torch.Tensor, np.ndarray)):
+        raise RuntimeError(
+            f"Linear layer {name} has weight of type {type(weight)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
+        )
+    elif isinstance(weight, (torch.Tensor, np.ndarray)):
+        weight = get_trt_tensor(network, weight, f"{name}_weight")
+
+    # Process bias terms
+    if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)):
+        raise RuntimeError(
+            f"Linear layer {name} has bias of type {type(bias)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
+        )
+    elif isinstance(bias, (torch.Tensor, np.ndarray)):
+        bias = get_trt_tensor(network, bias, f"{name}_bias")
+
+    # add IMatrixMultiplyLayer
+    out = impl.matmul.matrix_multiply(
+        network,
+        target,
+        source_ir,
+        name,
+        input,
+        weight,
+        input_matrix_op=trt.MatrixOperation.NONE,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+
+    if bias is not None:
+        # add bias
+        out = impl.elementwise.add(network, target, source_ir, name, out, bias)
+
+    return out
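A rough end-to-end usage sketch (not part of the commit; it assumes the torch_tensorrt.compile entry point with ir="dynamo" and a CUDA device, and the exact signature may differ between releases) showing the kind of module this converter is meant to handle:

    import torch
    import torch_tensorrt

    class TinyMLP(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # nn.Linear is expected to reach the converter as aten.linear.default
            # (this depends on the decomposition settings the dynamo frontend uses).
            self.fc = torch.nn.Linear(4, 3)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc(x)

    model = TinyMLP().eval().cuda()
    inputs = [torch.randn(2, 4, device="cuda")]

    # aten_ops_linear should route the op to the linear() implementation above,
    # building a transposed matmul plus an elementwise bias add in TensorRT.
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
    print(trt_model(*inputs).shape)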
