
Commit b2ac5f0

feat: support conv dynamo converter
1 parent 91fcea4 commit b2ac5f0
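
This commit registers a Dynamo converter for `aten.convolution`, so graphs containing standard (non-transposed) convolutions can be lowered to TensorRT through the dynamo path. A minimal usage sketch follows; the entry point (`torch_tensorrt.compile` with `ir="dynamo"`) and its argument names are assumptions about the public API and may differ at this revision.

# Hypothetical usage sketch (API names assumed, not taken from this commit).
import torch
import torch_tensorrt

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# The dynamo frontend traces the model to an ATen graph; aten.convolution nodes
# are handled by the converter added in this commit.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # expected: torch.Size([1, 16, 224, 224])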

File tree

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
py/torch_tensorrt/dynamo/conversion/impl/__init__.py
py/torch_tensorrt/dynamo/conversion/impl/conv.py

3 files changed: +161 -0 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+31
@@ -420,3 +420,34 @@ def aten_ops_clone(
         name,
         args[0],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.convolution.default)
+def aten_ops_convolution(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # we do not handle transposed.
+    if args[6] is True:
+        raise RuntimeError(f"Target {target} does not support `transposed=True`")
+    # we do not handle output_padding.
+    if args[7] not in ([0], [0, 0], [0, 0, 0]):
+        raise RuntimeError(f"Target {target} has non-0 output_padding")
+
+    return impl.conv.convNd(
+        network,
+        target,
+        source_ir=SourceIR.ATEN,
+        name=name,
+        is_conv1d=len(args[3]) == 1,
+        input=args[0],
+        weight=args[1],
+        bias=args[2],
+        stride=args[3],
+        padding=args[4],
+        dilation=args[5],
+        groups=args[8],
+    )
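
For reference, `aten_ops_convolution` indexes positionally into the `aten.convolution.default` schema: `(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)`. The short sketch below (plain PyTorch, not part of this commit) exercises that layout and marks which positions the guards above check.

# Positional layout of torch.ops.aten.convolution.default, as indexed by the
# converter above. Illustrative sketch only, not code from the commit.
import torch

x = torch.randn(1, 3, 8, 8)    # args[0] input
w = torch.randn(16, 3, 3, 3)   # args[1] weight
b = torch.randn(16)            # args[2] bias

out = torch.ops.aten.convolution.default(
    x, w, b,
    [1, 1],   # args[3] stride (len 2, so is_conv1d is False)
    [1, 1],   # args[4] padding
    [1, 1],   # args[5] dilation
    False,    # args[6] transposed (True is rejected by the converter)
    [0, 0],   # args[7] output_padding (non-zero is rejected)
    1,        # args[8] groups
)
assert out.shape == (1, 16, 8, 8)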

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

+1
@@ -4,6 +4,7 @@
     activation,
     cast,
     condition,
+    conv,
     elementwise,
     embedding,
     matmul,
py/torch_tensorrt/dynamo/conversion/impl/conv.py

+129
@@ -0,0 +1,129 @@
+from typing import Optional, Sequence, Union
+
+import numpy as np
+
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+from torch.fx.node import Target
+from torch_tensorrt.dynamo.conversion import aten_ops_converters
+from torch_tensorrt.fx.converters.converter_utils import (
+    SourceIR,
+    extend_attr_to_tuple,
+    get_dyn_range,
+    get_trt_tensor,
+    has_dynamic_shape,
+    mark_as_int8_layer,
+    set_layer_name,
+    to_numpy,
+)
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+
+def convNd(
+    network: TRTNetwork,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    is_conv1d: bool,
+    input: TRTTensor,
+    weight: Union[TRTTensor, torch.Tensor],
+    bias: Optional[Union[TRTTensor, torch.Tensor]],
+    stride: Optional[Union[int, Sequence[int]]],
+    padding: Optional[Union[int, Sequence[int]]],
+    dilation: Optional[Union[int, Sequence[int]]],
+    groups: Optional[int],
+    scale: Optional[Union[torch.Tensor, float]] = None,
+    zero_point: Optional[Union[torch.Tensor, float]] = None,
+) -> TRTTensor:
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."
+
+    if is_conv1d:
+        # Apply an unsqueeze operation to transform the conv1d problem into conv2d
+        input = aten_ops_converters.aten_ops_unsqueeze(
+            network, target, (input, -1), {}, name + "_unsqueeze"
+        )
+
+    # Process bias terms
+    if isinstance(bias, torch.Tensor):
+        # Transform the bias constant into a Numpy array
+        bias = to_numpy(bias)
+
+    elif isinstance(bias, TRTTensor):
+        bias = get_trt_tensor(network, bias, f"{name}_bias")
+
+    elif bias is not None:
+        raise RuntimeError(
+            f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
+        )
+
+    # Process weight terms
+    if network.has_explicit_precision or isinstance(weight, TRTTensor):
+        weight = get_trt_tensor(network, weight, f"{name}_weight")
+        # Append new dimension (unsqueeze) if the convolution is 1d
+        if is_conv1d:
+            weight = aten_ops_converters.aten_ops_unsqueeze(
+                network, target, (weight, -1), {}, name + "_unsqueeze_weight"
+            )
+
+    elif isinstance(weight, torch.Tensor):
+        # Transform the weight constant into a Numpy array
+        weight = to_numpy(weight)
+
+        # Append new dimension (unsqueeze) if the convolution is 1d
+        if is_conv1d:
+            weight = np.expand_dims(weight, -1)
+
+    else:
+        raise RuntimeError(
+            f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
+        )
+
+    # add conv layer
+    conv_layer = network.add_convolution_nd(
+        input=input,
+        num_output_maps=weight.shape[0],
+        kernel_shape=weight.shape[2:],
+        kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
+        bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
+    )
+
+    # If the weight is a TRTTensor, set it as an input of the layer
+    if isinstance(weight, TRTTensor):
+        conv_layer.set_input(1, weight)
+
+    # If the bias is a TRTTensor, set it as an input of the layer
+    if isinstance(bias, TRTTensor):
+        conv_layer.set_input(2, bias)
+
+    # Expand parameters manually for Conv1D computations
+    if is_conv1d:
+        padding = tuple(padding) + (0,)
+        stride = extend_attr_to_tuple(stride, 2)
+        dilation = extend_attr_to_tuple(dilation, 2)
+
+    set_layer_name(conv_layer, target, name, source_ir)
+
+    # Set relevant attributes of convolution layer
+    conv_layer.padding_nd = padding
+    conv_layer.stride_nd = stride
+    conv_layer.dilation_nd = dilation
+
+    if groups is not None:
+        conv_layer.num_groups = groups
+
+    # Handle quantization cases
+    if scale is not None and zero_point is not None:
+        # Assume the dtype of activation is torch.quint8
+        mark_as_int8_layer(conv_layer, get_dyn_range(scale, zero_point, torch.quint8))
+
+    result = conv_layer.get_output(0)
+
+    if is_conv1d:
+        # Apply a squeeze operation to transform the conv2d problem back into conv1d
+        result = aten_ops_converters.aten_ops_squeeze(
+            network, target, (result, -1), {}, name + "_squeeze"
+        )
+
+    return result
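
The `is_conv1d` branches above implement a conv1d-as-conv2d rewrite: a trailing unit dimension is unsqueezed onto the input and weight, the 2D convolution layer is built, and the extra dimension is squeezed off the output. The sketch below (plain PyTorch, not part of this commit) shows the same transformation and that it matches a direct 1D convolution.

# Conv1d expressed as conv2d via a trailing unit dimension: the rewrite that
# convNd applies when is_conv1d is True (illustrative sketch only).
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 16)   # (N, C_in, L)
w = torch.randn(8, 4, 3)    # (C_out, C_in, K)
b = torch.randn(8)

ref = F.conv1d(x, w, b, stride=2, padding=1, dilation=1)

x2 = x.unsqueeze(-1)        # (N, C_in, L, 1)
w2 = w.unsqueeze(-1)        # (C_out, C_in, K, 1)
out = F.conv2d(x2, w2, b, stride=(2, 1), padding=(1, 0), dilation=(1, 1))
out = out.squeeze(-1)       # back to (N, C_out, L_out)

assert torch.allclose(ref, out, atol=1e-5)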

0 commit comments
