Skip to content

Commit d43c5bb

Browse files
authored
feat: support conv dynamo converter (#2252)
1 parent 332295b commit d43c5bb

File tree

5 files changed

+189
-8
lines changed

5 files changed

+189
-8
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+30
Original file line numberDiff line numberDiff line change
@@ -843,3 +843,33 @@ def aten_ops_isinf(
843843
name,
844844
args[0],
845845
)
846+
847+
848+
def conv_param_validator(conv_node: Node) -> bool:
    """Report whether a convolution node can be handled by the converter.

    The node is accepted only when its positional argument at index 6 is
    falsy and its positional argument at index 7 equals one of the all-zero
    lists ``[0]``, ``[0, 0]`` or ``[0, 0, 0]``.

    NOTE(review): indices 6 and 7 presumably correspond to ``transposed`` and
    ``output_padding`` in the ``aten.convolution`` schema -- confirm against
    the operator signature.

    Args:
        conv_node: Node whose ``args`` are inspected.

    Returns:
        ``True`` if the node satisfies both conditions, ``False`` otherwise.
    """
    sixth_arg = conv_node.args[6]
    if sixth_arg:
        return False

    seventh_arg = conv_node.args[7]
    accepted_values = ([0], [0, 0], [0, 0, 0])
    return seventh_arg in accepted_values
850+
851+
852+
@dynamo_tensorrt_converter(
    torch.ops.aten.convolution.default, capability_validator=conv_param_validator
)
def aten_ops_convolution(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.convolution.default`` by delegating to ``impl.conv.convNd``.

    Positional ``args`` at indices 0-5 and 8 are forwarded as input, weight,
    bias, stride, padding, dilation and groups respectively. Indices 6 and 7
    are not forwarded; they are the values checked by the registered
    ``capability_validator``. ``kwargs`` is not used.

    The convolution is treated as 1-D when the stride sequence has exactly
    one element.

    Returns:
        Whatever ``impl.conv.convNd`` returns for the given operands.
    """
    # Name the positional operands up front so the call below reads clearly.
    conv_input = args[0]
    conv_weight = args[1]
    conv_bias = args[2]
    conv_stride = args[3]
    conv_padding = args[4]
    conv_dilation = args[5]
    conv_groups = args[8]

    # A single stride entry means a single spatial dimension.
    single_spatial_dim = len(conv_stride) == 1

    return impl.conv.convNd(
        network,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        is_conv1d=single_spatial_dim,
        input=conv_input,
        weight=conv_weight,
        bias=conv_bias,
        stride=conv_stride,
        padding=conv_padding,
        dilation=conv_dilation,
        groups=conv_groups,
    )

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import functools
22
import logging
33
import re
4-
from typing import List, Optional
4+
from typing import Any, List, Optional, Tuple
55

66
import tensorrt as trt
77
import torch
@@ -164,3 +164,27 @@ def broadcastable(
164164
get_axes_for_reduce_op = functools.partial(
165165
get_axes_for_reduce_op, has_implicit_batch_dimension=False
166166
)
167+
168+
169+
def extend_attr_to_tuple(
    val: Any,
    num_elem: int,
) -> Tuple[Any, ...]:
    """
    Normalize an attribute value into a tuple.

    A value that is neither a tuple nor a list is repeated `num_elem` times.
    A tuple or list holding exactly one element has that element repeated
    `num_elem` times. Any other list is converted to a tuple, and any other
    tuple is returned unchanged (its length is not checked against
    `num_elem`).

    Args:
        val (Any): Value that we want to process.
        num_elem (int): Number of repetitions used when `val` is a scalar or
            a one-element sequence.

    Returns:
        A tuple.
    """
    is_sequence = isinstance(val, (tuple, list))

    # Scalars are broadcast to the requested length.
    if not is_sequence:
        return (val,) * num_elem

    # One-element sequences are broadcast the same way.
    if len(val) == 1:
        return (val[0],) * num_elem

    # Longer (or empty) sequences are passed through, lists becoming tuples.
    return tuple(val) if isinstance(val, list) else val

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
activation,
55
cast,
66
condition,
7+
conv,
78
elementwise,
89
embedding,
910
matmul,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from typing import Optional, Sequence, Union
2+
3+
import numpy as np
4+
5+
# @manual=//deeplearning/trt/python:py_tensorrt
6+
import tensorrt as trt
7+
import torch
8+
from torch.fx.node import Target
9+
from torch_tensorrt.dynamo.conversion import impl
10+
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
11+
from torch_tensorrt.fx.converters.converter_utils import (
12+
SourceIR,
13+
get_dyn_range,
14+
get_trt_tensor,
15+
has_dynamic_shape,
16+
mark_as_int8_layer,
17+
set_layer_name,
18+
to_numpy,
19+
)
20+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
21+
22+
23+
def convNd(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    is_conv1d: bool,
    input: TRTTensor,
    weight: Union[TRTTensor, torch.Tensor],
    bias: Optional[Union[TRTTensor, torch.Tensor]],
    stride: Optional[Union[int, Sequence[int]]],
    padding: Optional[Union[int, Sequence[int]]],
    dilation: Optional[Union[int, Sequence[int]]],
    groups: Optional[int],
    scale: Optional[Union[torch.Tensor, float]] = None,
    zero_point: Optional[Union[torch.Tensor, float]] = None,
) -> TRTTensor:
    """Add an N-d convolution layer to ``network`` and return its output.

    When ``is_conv1d`` is set, the problem is lowered to a 2-D convolution:
    the input and the weight each get a trailing unit dimension, padding,
    stride and dilation are expanded to two entries, and the trailing
    dimension is squeezed off the result again.

    Args:
        network: Network the convolution layer is added to.
        target: Target forwarded to layer naming and helper converters.
        source_ir: Source IR forwarded to layer naming and helper converters.
        name: Base name used for the layer and any helper layers.
        is_conv1d: Whether the convolution has a single spatial dimension.
        input: Input tensor of the convolution.
        weight: Kernel, either a constant (torch tensor / numpy array) or a
            TRT tensor. TRT-tensor weights are attached as layer input 1.
        bias: Optional bias, either a constant or a TRT tensor. TRT-tensor
            biases are attached as layer input 2.
        stride: Stride assigned to ``stride_nd``.
        padding: Padding assigned to ``padding_nd``. For conv1d an ``int`` or
            a one-element sequence is accepted and extended with a zero for
            the added dimension.
        dilation: Dilation assigned to ``dilation_nd``.
        groups: Number of groups; left at the layer default when ``None``.
        scale: Quantization scale; only used together with ``zero_point``.
        zero_point: Quantization zero point; only used together with ``scale``.

    Returns:
        The output tensor of the convolution (squeezed back for conv1d).

    Raises:
        AssertionError: If the input shape is dynamic and its channel
            dimension (index 1) is ``-1``.
        RuntimeError: If ``bias`` or ``weight`` has an unsupported type.
    """
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."

    if is_conv1d:
        # Apply an unsqueeze operation to transform the conv1d problem into conv2d
        input = impl.unsqueeze.unsqueeze(
            network, target, source_ir, name + "_unsqueeze_conv1d", input, -1
        )

    # Process bias terms
    if isinstance(bias, (torch.Tensor, np.ndarray)):
        # Transform the bias constant into a Numpy array
        bias = to_numpy(bias)

    elif isinstance(bias, TRTTensor):
        bias = get_trt_tensor(network, bias, f"{name}_bias")

    elif bias is not None:
        raise RuntimeError(
            f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
        )

    # Process weight terms
    if network.has_explicit_precision or isinstance(weight, TRTTensor):
        weight = get_trt_tensor(network, weight, f"{name}_weight")
        # Append new dimension (unsqueeze) if the convolution is 1d.
        # The result must be bound to `weight`: binding it to `input` would
        # replace the activation tensor with the kernel and leave the kernel
        # without the extra dimension the 2-D lowering relies on.
        if is_conv1d:
            weight = impl.unsqueeze.unsqueeze(
                network, target, source_ir, name + "_unsqueeze_weight", weight, -1
            )

    elif isinstance(weight, (torch.Tensor, np.ndarray)):
        # Transform the weight constant into a Numpy array
        weight = to_numpy(weight)

        # Append new dimension (unsqueeze) if the convolution is 1d
        if is_conv1d:
            weight = np.expand_dims(weight, -1)

    else:
        raise RuntimeError(
            f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
        )

    # add conv layer; tensor-valued kernel/bias are passed as empty weights
    # here and wired in as layer inputs below
    conv_layer = network.add_convolution_nd(
        input=input,
        num_output_maps=weight.shape[0],
        kernel_shape=weight.shape[2:],
        kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
        bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
    )

    # If the weight is a TRTTensor, set it as an input of the layer
    if isinstance(weight, TRTTensor):
        conv_layer.set_input(1, weight)

    # If the bias is a TRTTensor, set it as an input of the layer
    if isinstance(bias, TRTTensor):
        conv_layer.set_input(2, bias)

    # Expand parameters manually for Conv1D computations
    if is_conv1d:
        # The added trailing dimension is never padded. Accept a bare int as
        # well as a sequence, since the signature allows both.
        if isinstance(padding, int):
            padding = (padding, 0)
        else:
            padding = tuple(padding) + (0,)
        stride = extend_attr_to_tuple(stride, 2)
        dilation = extend_attr_to_tuple(dilation, 2)

    set_layer_name(conv_layer, target, name, source_ir)

    # Set relevant attributes of convolution layer
    conv_layer.padding_nd = padding
    conv_layer.stride_nd = stride
    conv_layer.dilation_nd = dilation

    if groups is not None:
        conv_layer.num_groups = groups

    # Handle quantization cases
    if scale is not None and zero_point is not None:
        # Assume the dtype of activation is torch.quint8
        mark_as_int8_layer(conv_layer, get_dyn_range(scale, zero_point, torch.quint8))

    result = conv_layer.get_output(0)

    if is_conv1d:
        # Apply a squeeze operation to transform the conv2d problem back into conv1d
        result = impl.squeeze.squeeze(
            network, target, source_ir, name + "_squeeze_conv1d", result, -1
        )

    return result

py/torch_tensorrt/dynamo/conversion/impl/squeeze.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ def squeeze(
1818
input: TRTTensor,
1919
dim: Optional[Any] = None,
2020
) -> TRTTensor:
21-
if not isinstance(input, TRTTensor):
22-
raise RuntimeError(
23-
f"squeeze received input {input} that is not part "
24-
"of the TensorRT region!"
25-
)
2621
dims = []
2722
if dim is not None:
2823
if isinstance(dim, int):
@@ -35,6 +30,7 @@ def squeeze(
3530
# dim, which is a very rare case. For now we just claim not supporting dim=None.
3631
assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."
3732

33+
new_dims = []
3834
for dim in dims:
3935
dim = get_positive_dim(
4036
dim,
@@ -48,13 +44,14 @@ def squeeze(
4844
assert (
4945
len(get_dynamic_dims(input.shape)) <= 1
5046
), "Currently more than one dynamic dim for input to squeeze is not supported."
47+
new_dims.append(dim)
5148

5249
output_shape = []
5350
for i, s in enumerate(input.shape):
54-
if (i in dims) and s == 1:
51+
if (i in new_dims) and s == 1:
5552
continue
5653
output_shape.append(s)
5754
layer = network.add_shuffle(input)
5855
layer.reshape_dims = tuple(output_shape)
59-
set_layer_name(layer, target, name)
56+
set_layer_name(layer, target, name, source_ir)
6057
return layer.get_output(0)

0 commit comments

Comments
 (0)