21
21
from ..utils import get_dynamic_dims , torch_dtype_from_trt , torch_dtype_to_trt
22
22
23
23
from .converter_utils import * # noqa: F403
24
-
24
+ from torch_tensorrt .fx .passes .lower_basic_pass import (
25
+ trt_transposed_linear ,
26
+ trt_transposed_matmul ,
27
+ )
25
28
26
29
_LOGGER : logging .Logger = logging .getLogger (__name__ )
27
30
28
31
32
@tensorrt_converter(trt_transposed_matmul)
def trt_transposed_matmul_converter(network, target, args, kwargs, name):
    """Lower ``trt_transposed_matmul(lhs, rhs, lhs_t, rhs_t)`` to a TensorRT
    matrix-multiply layer, applying TRANSPOSE to whichever operand is flagged.

    Returns the single output ITensor of the created layer.
    """
    lhs, rhs, lhs_transposed, rhs_transposed = args

    # Operands that arrive as torch Parameters (frozen weights) must be
    # materialized as TensorRT tensors before they can feed a layer.
    if isinstance(lhs, torch.nn.Parameter):
        lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
    if isinstance(rhs, torch.nn.Parameter):
        rhs = get_trt_tensor(network, rhs, f"{name}_rhs")

    def _matrix_op(transposed):
        # Per-operand flag: transpose the operand inside the matmul layer.
        if transposed:
            return trt.MatrixOperation.TRANSPOSE
        return trt.MatrixOperation.NONE

    layer = network.add_matrix_multiply(
        lhs, _matrix_op(lhs_transposed), rhs, _matrix_op(rhs_transposed)
    )
    set_layer_name(layer, target, name)
    return layer.get_output(0)
48
+
49
+
50
@tensorrt_converter(trt_transposed_linear)
def trt_transposed_linear_converter(network, target, args, kwargs, name):
    """Lower ``trt_transposed_linear(input, weight, bias)`` to TensorRT layers.

    Builds input^T @ weight^T via a matrix-multiply layer (the input side is
    transposed inside the layer), then adds the bias with an elementwise SUM.
    Returns the output tensor of the bias-add.
    """
    inp, weight, bias = args

    # Weight is pre-transposed on the host; bias is reshaped to a row vector
    # so it broadcasts over the leading dimensions of the matmul result.
    weight = get_trt_tensor(network, weight.t(), f"{name}_weight")
    bias = get_trt_tensor(network, bias.reshape(1, -1), f"{name}_bias")

    # Align the ranks of the two matmul operands before the layer is created.
    inp, weight = broadcast(
        network, inp, weight, f"{inp.name}_broadcast", f"{weight.name}_broadcast"
    )

    mm_layer = network.add_matrix_multiply(
        inp, trt.MatrixOperation.TRANSPOSE, weight, trt.MatrixOperation.NONE
    )
    set_layer_name(mm_layer, target, f"{name}_mm")

    return add_binary_elementwise_layer(
        network,
        mm_layer.get_output(0),
        bias,
        trt.ElementWiseOperation.SUM,
        target,
        f"{name}_add",
    )
79
+
80
+
29
81
@tensorrt_converter (acc_ops .conv1d )
30
82
def acc_ops_conv1d (
31
83
network : TRTNetwork ,
@@ -1975,7 +2027,10 @@ def acc_ops_max_poolnd(
1975
2027
f"MaxPool2d received input { input_val } that is not part "
1976
2028
"of the TensorRT region!"
1977
2029
)
1978
- extend_len = 2 if target == acc_ops .max_pool2d else 3
2030
+ if target not in (acc_ops .max_pool2d , acc_ops .max_pool3d ):
2031
+ extend_len = 2 if len (kwargs ["kernel_size" ]) == 2 else 3
2032
+ else :
2033
+ extend_len = 2 if target == acc_ops .max_pool2d else 3
1979
2034
kernel_size = extend_attr_to_tuple (kwargs ["kernel_size" ], extend_len )
1980
2035
stride = extend_attr_to_tuple (kwargs ["stride" ], extend_len )
1981
2036
padding = extend_attr_to_tuple (kwargs ["padding" ], extend_len )
@@ -2259,8 +2314,11 @@ def acc_ops_adaptive_avg_poolnd(
2259
2314
f"AdaptiveAvgPool2d received input { input_val } that is not part "
2260
2315
"of the TensorRT region!"
2261
2316
)
2317
+ if target not in (acc_ops .adaptive_avg_pool3d , acc_ops .adaptive_avg_pool2d ):
2318
+ extend_len = 2 if len (kwargs ["output_size" ]) == 2 else 3
2319
+ else :
2320
+ extend_len = 2 if target == acc_ops .adaptive_avg_pool2d else 3
2262
2321
2263
- extend_len = 2 if target == acc_ops .adaptive_avg_pool2d else 3
2264
2322
assert all (
2265
2323
input_val .shape [- (i + 1 )] != - 1 for i in range (extend_len )
2266
2324
), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
@@ -2747,7 +2805,10 @@ def acc_ops_linear(
2747
2805
2748
2806
if isinstance (kwargs ["weight" ], torch .Tensor ):
2749
2807
weight = get_trt_tensor (network , kwargs ["weight" ].t (), f"{ name } _weight" )
2750
- weight_op = trt .MatrixOperation .NONE
2808
+ if target is not acc_ops .linear :
2809
+ weight_op = trt .MatrixOperation .TRANSPOSE
2810
+ else :
2811
+ weight_op = trt .MatrixOperation .NONE
2751
2812
else :
2752
2813
assert isinstance (
2753
2814
kwargs ["weight" ], TRTTensor
@@ -2782,17 +2843,26 @@ def acc_ops_linear(
2782
2843
return res
2783
2844
2784
2845
2785
- def add_clamp (network , input , val , op ):
2786
- acc_ops_clamp_shape = (1 ,) * len (input .shape ) # broadcast all dimensions
2787
- acc_ops_clamp_tensor = (
2788
- val
2789
- * torch .ones (acc_ops_clamp_shape , dtype = torch_dtype_from_trt (input .dtype ))
2790
- .cpu ()
2791
- .numpy ()
2792
- )
2793
- acc_ops_clamp_trt = network .add_constant (acc_ops_clamp_shape , acc_ops_clamp_tensor )
2794
- layer = network .add_elementwise (input , acc_ops_clamp_trt .get_output (0 ), op )
2795
-
2846
def add_clamp(network, input, val, op, name):
    """Create an elementwise layer that clamps ``input`` against the constant
    ``val`` using ``op`` (trt.ElementWiseOperation.MIN or .MAX).

    The constant is built to broadcast over every dimension of ``input``;
    a 0-d (scalar) input gets a scalar constant instead. Returns the layer
    (callers name it and take its output).
    """
    rank = len(input.shape)
    if rank == 0:
        # Scalar input: build the clamp value as a 0-d TRT tensor.
        clamp_const = get_trt_tensor(
            network,
            squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
            f"{name}_clamp_{val}",
        )
    else:
        # Shape of all ones, rank-matched to input, so the constant
        # broadcasts across every dimension.
        const_shape = (1,) * rank
        const_array = (
            val
            * torch.ones(const_shape, dtype=torch_dtype_from_trt(input.dtype))
            .cpu()
            .numpy()
        )
        clamp_const = network.add_constant(const_shape, const_array).get_output(0)

    return network.add_elementwise(input, clamp_const, op)
2797
2867
2798
2868
@@ -2816,13 +2886,13 @@ def acc_ops_clamp(
2816
2886
2817
2887
if min_val is not None :
2818
2888
clamp_min_layer = add_clamp (
2819
- network , input_val , min_val , trt .ElementWiseOperation .MAX
2889
+ network , input_val , min_val , trt .ElementWiseOperation .MAX , name
2820
2890
)
2821
2891
set_layer_name (clamp_min_layer , target , f"{ name } _clamp_min" )
2822
2892
input_val = clamp_min_layer .get_output (0 )
2823
2893
if max_val is not None :
2824
2894
clamp_max_layer = add_clamp (
2825
- network , input_val , max_val , trt .ElementWiseOperation .MIN
2895
+ network , input_val , max_val , trt .ElementWiseOperation .MIN , name
2826
2896
)
2827
2897
set_layer_name (clamp_max_layer , target , f"{ name } _clamp_max" )
2828
2898
input_val = clamp_max_layer .get_output (0 )
0 commit comments