Skip to content

Commit 0daf301

Browse files
author
Wei
authored
Changes done internally at Facebook (#1390)
7ae5990ba20126da1e0a93ad0887cb1892ff48cd Janet Yang <[email protected]> Pass to remove for _validate_and_get_n_vectors d269be2fc7d84738a642d1d53eb44e6886a28d0c Alex Beloi <[email protected]> [fx] add deferred weights (xl_weight) and tracing for xl_embedding_bag 6f233bc9c72d90a908db0548c9d2dbe853895137 Alex Beloi <[email protected]> [fx] fix out of bounds indices/offsets for embedding_bag ops with xl_weight 3ca3b21c6a85ab9a6e9de503d0f13ee713a7b67c Janet Yang <[email protected]> Support div, torch.norm 52955d93d25e857510ed1b765220e8e5b0b0bb08 Janet Yang <[email protected]> Pass to replace sum(elmtwise(X))/numel(X) w/ mean(elmtwise(X)) 89c56ef76a7a329f244a013ac5ccb099cb00c3c0 Janet Yang <[email protected]> Support scalar clamp, fixes for nan_to_num and benchmark 48071d8da1dc66fffceb0b42ea386079f1fb9709 Wei Wei <[email protected]> [ads] bug fix in push_down_parrallel_split_ops afdc533da031a64e162bb08c8629ff38739e24f8 Wei Wei <[email protected]> [fx2trt] disable dispatch trace leaf node test 9905612fd8e6e2e79dc2f2bd1fa5b5d7fd5c98c3 Shirong Wu <[email protected]> Add number constrain for fuse group ln d160a7a5e554d37c142e13f100bf4d8739ced232 Wei Wei <[email protected]> add option to remove passes c22f691e6eae1b06ecd301eb6285b32d5dc9717c Mike Iovine <[email protected]> [fx2trt] Support dict inputs in acc tracer 8c05a3c57b1f5c63108b979ef8c61411525d0b1f Mike Iovine <[email protected]> [fx2trt] Support namedtuple access in acc tracer getattr ff2000594e3f3ff75e0074edf9c38b5609128bbd Janet Yang <[email protected]> Generalize remove split ops more 1580805d827eb40c941e769b0b99e7c6a3ed6f89 Wei Wei <[email protected]> [fx2trt] add reshape unit test d6a975462071a3747d18edcbe87a3b143b3ece88 Archie Sravankumar <[email protected]> Added FX tracing for `log_softmax` 6943ac0e322077b36a03c50c4c9065de6cd32837 Sungmin Cho <[email protected]> Add replace_mutable_op lower pass baab27b81b1275de92fdaf760a158ce951564d33 Donglin Xia <[email protected]> Register avg_pool3d for acc_op in acc_op.py ae4c4e2c3c18d78542140fcc30e1c24f7c647ef3 Wei Wei <[email protected]> [aten2trt] init check-in fc94c5e110d5552349b2634662eae41f9f0b8933 Wei Wei <[email protected]> [ads] fix a bug in fuse_parallel_linear 87ef03338c9a25c5a610a2eb590345e8935f8d75 Wei Wei <[email protected]> [aten2trt] add binary ops fca64a5b09749284fc6028b510078257fd4717b1 Shirong Wu <[email protected]> Fix dper pass 2bb168517ace7e638cffc7a241b1cbf528790b92 Mike Iovine <[email protected]> [fx2trt] Add acc normalization blocklist 8c912e085cf8722d572698286020ae1ce055023d Zhijing Li (Accelerator Enablement) <[email protected]> Skip unstable test_conv_add_standalone_module 137a3977ffeb03d0387e8a95ff2f32f3d15b3de8 Wei Wei <[email protected]> [aten2trt] resnet support f06174dbb190df4ea488ca99a81d4884b5ed3aa2 wwei6 <[email protected]> [fx2trt] compile 817c1f0b6278ce0ad04dd88d43d21e7390e3baea wwei6 <[email protected]> [aten2trt] init check-in 92ce42c16f34804584a7e553eddf897c9fa4f65e wwei6 <[email protected]> [aten2trt] binary op
1 parent 75fdbf0 commit 0daf301

34 files changed

+2640
-215
lines changed

py/torch_tensorrt/fx/README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,20 @@ FX2TRT is merged as FX module in Torch-TensorRT
22

33
- The user guide is in [link](../../../docsrc/tutorials/getting_started_with_fx_path.rst#installation)
44
- The examples are moved to [link](../../../examples/fx)
5+
6+
* Method 1. Follow the instrucions for Torch-TensorRT
7+
* Method 2. To install FX path only (Python path) and avoid the C++ build for torchscript path
8+
`
9+
$ conda create --name python_env python=3.8
10+
$ conda activate python_env
11+
# Recommend to install PyTorch 1.12 and later
12+
$ conda install pytorch torchvision torchtext cudatoolkit=11.3 -c pytorch-nightly
13+
# Install TensorRT python package
14+
$ pip3 install nvidia-pyindex
15+
$ pip3 install nvidia-tensorrt==8.2.4.2
16+
$ git clone https://github.com/pytorch/TensorRT.git
17+
$ cd TensorRT/py && python setup.py install --fx-only && cd ..
18+
$ pyton -c "import torch_tensorrt.fx"
19+
# Test an example by
20+
$ python py/torch_tensorrt/fx/example/lower_example.py
21+
`

py/torch_tensorrt/fx/converters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .transformation import * # noqa: F401 F403
1414
from .quantization import * # noqa: F401 F403
1515
from .acc_ops_converters import * # noqa: F401 F403
16+
from .aten_ops_converters import * # noqa: F401 F403
1617

1718
TRT_LOGGER = trt.Logger()
1819
trt.init_libnvinfer_plugins(TRT_LOGGER, "")

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,63 @@
2121
from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
2222

2323
from .converter_utils import * # noqa: F403
24-
24+
from torch_tensorrt.fx.passes.lower_basic_pass import (
25+
trt_transposed_linear,
26+
trt_transposed_matmul,
27+
)
2528

2629
_LOGGER: logging.Logger = logging.getLogger(__name__)
2730

2831

32+
@tensorrt_converter(trt_transposed_matmul)
33+
def trt_transposed_matmul_converter(network, target, args, kwargs, name):
34+
lhs, rhs, lhs_transposed, rhs_transposed = args
35+
36+
if isinstance(lhs, torch.nn.Parameter):
37+
lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
38+
if isinstance(rhs, torch.nn.Parameter):
39+
rhs = get_trt_tensor(network, rhs, f"{name}_rhs")
40+
layer = network.add_matrix_multiply(
41+
lhs,
42+
trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,
43+
rhs,
44+
trt.MatrixOperation.TRANSPOSE if rhs_transposed else trt.MatrixOperation.NONE,
45+
)
46+
set_layer_name(layer, target, name)
47+
return layer.get_output(0)
48+
49+
50+
@tensorrt_converter(trt_transposed_linear)
51+
def trt_transposed_linear_converter(network, target, args, kwargs, name):
52+
input, weight, bias = args
53+
54+
weight = get_trt_tensor(network, weight.t(), f"{name}_weight")
55+
bias = get_trt_tensor(network, bias.reshape(1, -1), f"{name}_bias")
56+
57+
input, weight = broadcast(
58+
network,
59+
input,
60+
weight,
61+
f"{input.name}_broadcast",
62+
f"{weight.name}_broadcast",
63+
)
64+
layer = network.add_matrix_multiply(
65+
input,
66+
trt.MatrixOperation.TRANSPOSE,
67+
weight,
68+
trt.MatrixOperation.NONE,
69+
)
70+
set_layer_name(layer, target, f"{name}_mm")
71+
return add_binary_elementwise_layer(
72+
network,
73+
layer.get_output(0),
74+
bias,
75+
trt.ElementWiseOperation.SUM,
76+
target,
77+
f"{name}_add",
78+
)
79+
80+
2981
@tensorrt_converter(acc_ops.conv1d)
3082
def acc_ops_conv1d(
3183
network: TRTNetwork,
@@ -1975,7 +2027,10 @@ def acc_ops_max_poolnd(
19752027
f"MaxPool2d received input {input_val} that is not part "
19762028
"of the TensorRT region!"
19772029
)
1978-
extend_len = 2 if target == acc_ops.max_pool2d else 3
2030+
if target not in (acc_ops.max_pool2d, acc_ops.max_pool3d):
2031+
extend_len = 2 if len(kwargs["kernel_size"]) == 2 else 3
2032+
else:
2033+
extend_len = 2 if target == acc_ops.max_pool2d else 3
19792034
kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len)
19802035
stride = extend_attr_to_tuple(kwargs["stride"], extend_len)
19812036
padding = extend_attr_to_tuple(kwargs["padding"], extend_len)
@@ -2259,8 +2314,11 @@ def acc_ops_adaptive_avg_poolnd(
22592314
f"AdaptiveAvgPool2d received input {input_val} that is not part "
22602315
"of the TensorRT region!"
22612316
)
2317+
if target not in (acc_ops.adaptive_avg_pool3d, acc_ops.adaptive_avg_pool2d):
2318+
extend_len = 2 if len(kwargs["output_size"]) == 2 else 3
2319+
else:
2320+
extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
22622321

2263-
extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3
22642322
assert all(
22652323
input_val.shape[-(i + 1)] != -1 for i in range(extend_len)
22662324
), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
@@ -2747,7 +2805,10 @@ def acc_ops_linear(
27472805

27482806
if isinstance(kwargs["weight"], torch.Tensor):
27492807
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
2750-
weight_op = trt.MatrixOperation.NONE
2808+
if target is not acc_ops.linear:
2809+
weight_op = trt.MatrixOperation.TRANSPOSE
2810+
else:
2811+
weight_op = trt.MatrixOperation.NONE
27512812
else:
27522813
assert isinstance(
27532814
kwargs["weight"], TRTTensor
@@ -2782,17 +2843,26 @@ def acc_ops_linear(
27822843
return res
27832844

27842845

2785-
def add_clamp(network, input, val, op):
2786-
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
2787-
acc_ops_clamp_tensor = (
2788-
val
2789-
* torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
2790-
.cpu()
2791-
.numpy()
2792-
)
2793-
acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor)
2794-
layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op)
2795-
2846+
def add_clamp(network, input, val, op, name):
2847+
if not len(input.shape):
2848+
# clamping scalar
2849+
acc_ops_clamp_trt = get_trt_tensor(
2850+
network,
2851+
squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
2852+
f"{name}_clamp_{val}",
2853+
)
2854+
else:
2855+
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
2856+
acc_ops_clamp_tensor = (
2857+
val
2858+
* torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
2859+
.cpu()
2860+
.numpy()
2861+
)
2862+
acc_ops_clamp_trt = network.add_constant(
2863+
acc_ops_clamp_shape, acc_ops_clamp_tensor
2864+
).get_output(0)
2865+
layer = network.add_elementwise(input, acc_ops_clamp_trt, op)
27962866
return layer
27972867

27982868

@@ -2816,13 +2886,13 @@ def acc_ops_clamp(
28162886

28172887
if min_val is not None:
28182888
clamp_min_layer = add_clamp(
2819-
network, input_val, min_val, trt.ElementWiseOperation.MAX
2889+
network, input_val, min_val, trt.ElementWiseOperation.MAX, name
28202890
)
28212891
set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
28222892
input_val = clamp_min_layer.get_output(0)
28232893
if max_val is not None:
28242894
clamp_max_layer = add_clamp(
2825-
network, input_val, max_val, trt.ElementWiseOperation.MIN
2895+
network, input_val, max_val, trt.ElementWiseOperation.MIN, name
28262896
)
28272897
set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
28282898
input_val = clamp_max_layer.get_output(0)

0 commit comments

Comments
 (0)