Commit 00d89fc

feat: support linear (fully connected layer) dynamo converter
refactor linear func
1 parent: b1c7285

File tree

3 files changed: +73 -0 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+19
@@ -879,3 +879,22 @@ def aten_ops_convolution(
        dilation=args[5],
        groups=args[8],
    )


@dynamo_tensorrt_converter(torch.ops.aten.linear)
def aten_ops_linear(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.linear.linear(
        network,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        weight=args[1],
        bias=args_bounds_check(args, 2, None),
    )
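Note on the optional bias: `aten.linear(input, weight, bias=None)` may be traced without a third positional argument, so the converter reads it through `args_bounds_check` instead of indexing `args[2]` directly. The snippet below is only a sketch of how such a bounds-checked lookup behaves, under the assumption that the helper returns a default when the index is out of range; it is not the exact implementation shipped in this commit, and the name `args_bounds_check_sketch` is illustrative.

```python
from typing import Any, Optional, Sequence


def args_bounds_check_sketch(
    args: Sequence[Any], i: int, replacement: Optional[Any] = None
) -> Any:
    # Return args[i] when the node actually carries that positional argument,
    # otherwise fall back to the supplied default (None for the linear bias).
    return args[i] if len(args) > i else replacement


# A linear call traced without a bias only has two positional args,
# so index 2 falls back to None.
assert args_bounds_check_sketch(("x", "w"), 2, None) is None
assert args_bounds_check_sketch(("x", "w", "b"), 2, None) == "b"
```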

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

+1
@@ -7,6 +7,7 @@
    conv,
    elementwise,
    embedding,
    linear,
    matmul,
    normalization,
    permutation,
py/torch_tensorrt/dynamo/conversion/impl/linear.py

+53
@@ -0,0 +1,53 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.fx.converters.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def linear(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    weight: Union[TRTTensor, torch.Tensor, np.ndarray],
    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
) -> TRTTensor:
    # Process weight terms
    if not isinstance(weight, (TRTTensor, torch.Tensor, np.ndarray)):
        raise RuntimeError(
            f"Linear layer {name} has weight of type {type(weight)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
        )
    elif isinstance(weight, (torch.Tensor, np.ndarray)):
        weight = get_trt_tensor(network, weight, f"{name}_weight")

    # Process bias terms
    if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)):
        raise RuntimeError(
            f"Linear layer {name} has bias of type {type(bias)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
        )
    elif isinstance(bias, (torch.Tensor, np.ndarray)):
        bias = get_trt_tensor(network, bias, f"{name}_bias")

    # add IMatrixMultiplyLayer
    out = impl.matmul.matrix_multiply(
        network,
        target,
        source_ir,
        name,
        input,
        weight,
        input_matrix_op=trt.MatrixOperation.NONE,
        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
    )

    if bias is not None:
        # add bias
        out = impl.elementwise.add(network, target, source_ir, name, out, bias)

    return out
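The implementation lowers `aten.linear` to a matrix multiply against the transposed weight (`MatrixOperation.TRANSPOSE`) followed by an elementwise bias add, which matches the definition of `torch.nn.functional.linear` (output = input @ weight^T + bias, with the weight stored as (out_features, in_features)). The check below is a PyTorch-only illustration of that decomposition; it does not build a TensorRT network and is not part of this commit.

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16)   # (batch, in_features)
w = torch.randn(32, 16)  # (out_features, in_features), as nn.Linear stores it
b = torch.randn(32)      # (out_features,)

reference = F.linear(x, w, b)
# matmul with the transposed weight, then an elementwise bias add,
# mirroring the NONE/TRANSPOSE matrix ops plus elementwise.add above
decomposed = torch.matmul(x, w.transpose(0, 1)) + b

assert torch.allclose(reference, decomposed, atol=1e-6)
```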
