diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 792f58955b..9fcf959346 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1289,3 +1289,22 @@ def aten_ops_convolution(
         dilation=args[5],
         groups=args[8],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.linear.default)
+def aten_ops_linear(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.linear.linear(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        weight=args[1],
+        bias=args_bounds_check(args, 2, None),
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index 4ee7fd2bed..db7c877e8f 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -7,6 +7,7 @@
     conv,
     elementwise,
     embedding,
+    linear,
     matmul,
     normalization,
     permutation,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/linear.py b/py/torch_tensorrt/dynamo/conversion/impl/linear.py
new file mode 100644
index 0000000000..0a98087bce
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/linear.py
@@ -0,0 +1,53 @@
+from typing import Optional, Union
+
+import numpy as np
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.fx.converters.converter_utils import SourceIR, get_trt_tensor
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def linear(
+    network: TRTNetwork,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    weight: Union[TRTTensor, torch.Tensor, np.ndarray],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+) -> TRTTensor:
+    # Process weight terms
+    if not isinstance(weight, (TRTTensor, torch.Tensor, np.ndarray)):
+        raise RuntimeError(
+            f"Linear layer {name} has weight of type {type(weight)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
+        )
+    elif isinstance(weight, (torch.Tensor, np.ndarray)):
+        weight = get_trt_tensor(network, weight, f"{name}_weight")
+
+    # Process bias terms
+    if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)):
+        raise RuntimeError(
+            f"Linear layer {name} has bias of type {type(bias)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
+        )
+    elif isinstance(bias, (torch.Tensor, np.ndarray)):
+        bias = get_trt_tensor(network, bias, f"{name}_bias")
+
+    # add IMatrixMultiplyLayer
+    out = impl.matmul.matrix_multiply(
+        network,
+        target,
+        source_ir,
+        name,
+        input,
+        weight,
+        input_matrix_op=trt.MatrixOperation.NONE,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+
+    if bias is not None:
+        # add bias
+        out = impl.elementwise.add(network, target, source_ir, name, out, bias)
+
+    return out
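
Usage note (not part of the diff): the sketch below is a minimal, hypothetical smoke test for the new converter path. It assumes `torch_tensorrt.compile` with `ir="dynamo"` is available and that `aten.linear.default` survives lowering and reaches `aten_ops_linear`; depending on which decompositions are active, `torch.nn.Linear` may instead be lowered to matmul + add and bypass this converter. The module name and tensor shapes are illustrative only.

```python
import torch
import torch_tensorrt


class TinyLinear(torch.nn.Module):
    """Single nn.Linear layer, enough to emit one aten.linear.default call."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(128, 64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


model = TinyLinear().eval().cuda()
inputs = [torch.randn(2, 128, device="cuda")]

# Compile through the dynamo frontend; min_block_size=1 so the single
# linear op is not skipped for forming too small a TRT block.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)

print(trt_model(*inputs).shape)  # expected: torch.Size([2, 64])
```

Design-wise, the new `impl.linear.linear` realizes the op as an IMatrixMultiplyLayer with the weight operand transposed (matching the `(out_features, in_features)` layout of `nn.Linear.weight`), followed by an elementwise add for the optional bias, rather than relying on TensorRT's deprecated IFullyConnectedLayer.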