From 815751b1c1491d41fbdcb9822454de62efe7912b Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Tue, 19 Mar 2024 18:10:14 -0700 Subject: [PATCH 01/20] fix: FakeTensors appearing in `get_attr` calls (#2669) --- .../lowering/passes/constant_folding.py | 45 ++++++++++++++++--- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 81a9d76d6e..76e79ac100 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -11,12 +11,9 @@ # Modify import location of utilities based on Torch version if version.parse(sanitized_torch_version()) < version.parse("2.1.1"): - from torch._inductor.freezing import ConstantFolder, replace_node_with_constant + from torch._inductor.freezing import ConstantFolder else: - from torch._inductor.constant_folding import ( - ConstantFolder, - replace_node_with_constant, - ) + from torch._inductor.constant_folding import ConstantFolder logger = logging.getLogger(__name__) @@ -36,7 +33,9 @@ def constant_fold( cf.run() for node, constant in cf.node_replacements.items(): - replace_node_with_constant(gm, node, constant) + replace_node_with_constant( + gm, node, torch.nn.Parameter(constant.cuda(), requires_grad=False) + ) erased_params = [] for node in gm.graph.nodes: @@ -55,6 +54,40 @@ def constant_fold( return gm +def replace_node_with_constant( + gm: torch.fx.GraphModule, node: torch.fx.Node, constant: torch.Tensor +) -> None: + """Adapted from: + https://github.com/pytorch/pytorch/blob/bcf35c6ae62bb6560befa3550e37a8283944e5f4/torch/_inductor/constant_folding.py#L17-L43 + + Modified to register parameters, instead of buffers for frozen constants + """ + g = gm.graph + + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 + + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 + + gm._frozen_param_count = i + 1 + + with g.inserting_before(node): + new_input_node = g.create_node("get_attr", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + + # Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_parameter(qualname, constant) + setattr(gm, qualname, constant) + + # TODO: Delete this class when the following code is fixed in nightly: # https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63 class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc] From 5930e961546046607a9b8bb43a7c11919b95a33a Mon Sep 17 00:00:00 2001 From: "Zewen (Evan) Li" Date: Thu, 21 Mar 2024 16:35:35 -0700 Subject: [PATCH 02/20] feat: support adaptive_avg_pool1d dynamo converter (#2614) --- .../dynamo/conversion/aten_ops_converters.py | 18 +++ .../dynamo/conversion/impl/pool.py | 67 ++++++++- .../conversion/test_adaptive_avgpool_aten.py | 139 +++++++----------- 3 files changed, 141 insertions(+), 83 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 72998e1917..daea3fe385 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2200,6 +2200,24 @@ def aten_ops_avg_pool( ) +@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default) +def aten_ops_adaptive_avg_pool( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pool.adaptive_avg_pool1d( + ctx, + target, + source_ir=SourceIR.ATEN, + name=name, + input=args[0], + output_size=args[1], + ) + + def max_pool_param_validator(pool_node: Node) -> bool: dilation = args_bounds_check(pool_node.args, 4, 1) ceil_mode = args_bounds_check(pool_node.args, 5, False) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py index 13c8645a90..8c16f59030 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py @@ -1,6 +1,8 @@ -from typing import Optional, Sequence, Union +import math +from typing import Dict, Optional, Sequence, Union import tensorrt as trt +import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -104,3 +106,66 @@ def max_poolNd( set_layer_name(pool_layer, target, name, source_ir) return pool_layer.get_output(0) + + +def adaptive_avg_pool1d( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + output_size: Union[int, Sequence[int]], +) -> TRTTensor: + def start_index(idx: int, out_dim: int, in_dim: int) -> int: + """Calculate the start index of each pooling window""" + return math.floor((float(idx) * float(in_dim)) / out_dim) + + def end_index(idx: int, out_dim: int, in_dim: int) -> int: + """Calculate the end index of each pooling window""" + return math.ceil((float(idx + 1) * float(in_dim)) / out_dim) + + in_dim = input.shape[-1] + out_dim = output_size if isinstance(output_size, int) else output_size[0] + output_list = [] + + # store {index: slice} for reducing repeated slice ops + idx_slice_map: Dict[int, TRTTensor] = {} + # iterate over each output dimension + for i in range(out_dim): + # calculate the start and end index of each pooling window + start = start_index(i, out_dim, in_dim) + end = end_index(i, out_dim, in_dim) + + # slice the input tensor from start to end index, the result of which is the window waiting for pooling + slices = [] + for j in range(start, end): + if j in idx_slice_map: + slice = idx_slice_map[j] + else: + slice = impl.select.select( + ctx, target, source_ir, f"{name}_select_{j}", input, -1, j + ) + slice = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_{i}_{j}", + slice, + (*slice.shape, 1), + ) + idx_slice_map[j] = slice + + slices.append(slice) + + slices = impl.cat.cat( + ctx, target, source_ir, f"{name}_slices_cat_{i}", slices, dim=-1 + ) + # calculate the mean of the slices (average pooling output) and append to the output list + output_list.append( + impl.reduce.mean( + ctx, target, source_ir, f"{name}_sum_{i}", slices, dim=-1, keepdim=True + ) + ) + + output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1) + return output diff --git a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py index e19e1b6187..3d48409631 100644 --- a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py +++ b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py @@ -9,102 +9,77 @@ class TestAdaptiveAvgPoolConverter(DispatchTestCase): @parameterized.expand( [ - ((64, 64),), - ((128, 64),), - # (64,), This case has been there in previous code but it isn't a valid pytorch code. - ] - ) - def test_adaptive_avgpool( - self, - output_size, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d(output_size) - - def forward(self, x): - return self.pool(x) - - inputs = [torch.randn(1, 3, 256, 256)] - self.run_test( - TestModule(), - inputs, - use_dynamo_tracer=True, - ) - - def test_adaptive_avgpool_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool2d((64, 64)) - - def forward(self, x): - return self.pool(x) - - input_specs = [ - Input( - shape=(-1, -1, 256, 256), - dtype=torch.float32, - shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], + ( + (2, 3), + 2, + ), + ( + (2, 8), + 8, + ), + ( + (1, 2, 3), + 2, + ), + ( + (2, 2, 8), + 16, + ), + ( + (2, 3), + (1,), + ), + ( + (2, 3), + (2,), + ), + ( + (2, 8), + (4,), + ), + ( + (2, 8), + (16,), + ), + ( + (2, 3, 1), + (1,), + ), + ( + (2, 3, 2), + (2,), + ), + ( + (2, 3, 4), + (4,), + ), + ( + (2, 2, 32), + (31,), + ), + ( + (2, 2, 32), + (64,), ), - ] - self.run_test_with_dynamic_shape( - TestModule(), input_specs, use_dynamo_tracer=True - ) - - @parameterized.expand( - [ - ((16, 16, 16),), - ((32, 16, 4),), - (32,), ] ) - def test_adaptive_avgpool3d( + def test_adaptive_avg_pool1d( self, + input_shape, output_size, ): class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool3d(output_size) - def forward(self, x): - return self.pool(x) + return torch.ops.aten.adaptive_avg_pool1d.default(x, output_size) - inputs = [torch.randn(1, 3, 32, 64, 64)] + inputs = [torch.randn(input_shape)] self.run_test( TestModule(), inputs, - use_dynamo_tracer=True, + # use_dynamo_tracer=True, + enable_passes=True, ) - def test_adaptive_avgpool3d_with_dynamic_shape(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16)) - - def forward(self, x): - return self.pool(x) - - input_specs = [ - Input( - shape=(-1, -1, 32, 64, 64), - dtype=torch.float32, - shape_ranges=[ - ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64)) - ], - ), - ] - self.run_test_with_dynamic_shape( - TestModule(), - input_specs, - use_dynamo_tracer=True, - ) - - # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." - if __name__ == "__main__": run_tests() From e150913e99b4922b2c4b208945b55ac915b55061 Mon Sep 17 00:00:00 2001 From: MizuKuma <33080670+Arktische@users.noreply.github.com> Date: Fri, 22 Mar 2024 09:05:21 +0800 Subject: [PATCH 03/20] fix: Add cmake missing source file ref for core_lowering.passes (#2672) --- core/lowering/passes/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/core/lowering/passes/CMakeLists.txt b/core/lowering/passes/CMakeLists.txt index 25b320c6c3..085faf3e0d 100644 --- a/core/lowering/passes/CMakeLists.txt +++ b/core/lowering/passes/CMakeLists.txt @@ -26,6 +26,7 @@ target_sources(${lib_name} "${CMAKE_CURRENT_SOURCE_DIR}/unpack_rsqrt.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/unpack_scaled_dot_product_attention.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/rewrite_inputs_with_params.cpp" ) From 4bb60f3471ee1e3b8ddc91018dff8d7ab67ea483 Mon Sep 17 00:00:00 2001 From: HolyWu Date: Fri, 5 Apr 2024 09:00:39 +0800 Subject: [PATCH 04/20] Add support for `aten.pixel_unshuffle` dynamo converter (#2696) --- .../dynamo/conversion/aten_ops_converters.py | 23 ++++++++++ .../dynamo/conversion/impl/shuffle.py | 44 +++++++++++++++++++ .../conversion/test_pixel_unshuffle_aten.py | 29 ++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index daea3fe385..67380387db 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2335,6 +2335,29 @@ def aten_ops_pixel_shuffle( ) +@dynamo_tensorrt_converter(torch.ops.aten.pixel_unshuffle.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_pixel_unshuffle( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.shuffle.pixel_unshuffle( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @enforce_tensor_types({0: (TRTTensor,)}) @dynamo_tensorrt_converter(torch.ops.aten.argmax.default) def aten_ops_argmax( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py index 6d848c4be3..b2a79af5cb 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py @@ -76,3 +76,47 @@ def pixel_shuffle( permuted_tensor, shape[:-3] + (out_channels, out_height, out_width), ) + + +def pixel_unshuffle( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + downscale_factor: int, +) -> TRTTensor: + shape = input.shape + in_channels, in_height, in_width = shape[-3:] + out_channels = in_channels * (downscale_factor**2) + out_height = in_height // downscale_factor + out_width = in_width // downscale_factor + new_shape = shape[:-3] + ( + in_channels, + out_height, + downscale_factor, + out_width, + downscale_factor, + ) + reshaped_tensor = reshape( + ctx, target, source_ir, f"{name}_reshape1", input, new_shape + ) + rank = len(new_shape) + permute_shape = tuple(range(rank - 5)) + ( + rank - 5, # in_channels + rank - 3, # downscale_factor + rank - 1, # downscale_factor + rank - 4, # out_height + rank - 2, # out_width + ) + permuted_tensor = impl.permutation.permute( + ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape + ) + return reshape( + ctx, + target, + source_ir, + f"{name}_reshape2", + permuted_tensor, + shape[:-3] + (out_channels, out_height, out_width), + ) diff --git a/tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py b/tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py new file mode 100644 index 0000000000..fb93e68499 --- /dev/null +++ b/tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py @@ -0,0 +1,29 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestPixelUnshuffleConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1, 1, 1), 1), + ((1, 1, 12, 12), 3), + ((2, 3, 4, 25, 30), 5), + ] + ) + def test_pixel_unshuffle(self, shape, downscale_factor): + class PixelUnshuffle(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.pixel_unshuffle.default(x, downscale_factor) + + inputs = [torch.randn(shape)] + self.run_test( + PixelUnshuffle(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From 4314fbc8de533159c7a6f0289fb5cc8220ee1458 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Fri, 12 Apr 2024 09:36:05 +0900 Subject: [PATCH 05/20] feat: support aten.atan2 converter (#2689) --- .../dynamo/conversion/aten_ops_converters.py | 24 +++ .../dynamo/conversion/impl/elementwise/ops.py | 179 +++++++++++++++++- tests/py/dynamo/conversion/test_atan2_aten.py | 132 +++++++++++++ 3 files changed, 334 insertions(+), 1 deletion(-) create mode 100644 tests/py/dynamo/conversion/test_atan2_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 67380387db..d99ba6ef4b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1407,6 +1407,30 @@ def aten_ops_atanh( ) +@dynamo_tensorrt_converter(torch.ops.aten.atan2.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + 1: (TRTTensor,), + } +) +def aten_ops_atan2( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.atan2( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.ceil.default) def aten_ops_ceil( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index 81c3a3e867..9a5087e469 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -1,5 +1,6 @@ from typing import Optional, Union +import numpy as np import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl @@ -10,13 +11,15 @@ from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_int_int_div_trt_tensor, cast_int_or_float_to_bool, + cast_trt_tensor, get_trt_tensor, ) from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) -from torch_tensorrt.dynamo.conversion.impl.unary import sign +from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary +from torch_tensorrt.fx.converters.converter_utils import broadcast from torch_tensorrt.fx.types import TRTTensor @@ -213,6 +216,180 @@ def remainder( return fmod2_value +def atan2( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + """ + Perform atan2 operation on Tensor, calculating the arctangent of the quotient of input tensors. + atan2(x,y) = atan(x/y) if y > 0, + = atan(x/y) + π if x ≥ 0 and y < 0, + = atan(x/y) - π if x < 0 and y < 0, + = π/2 if x > 0 and y = 0, + = -π/2 if x < 0 and y = 0, + = 0 if x = 0 and y = 0 + + Args: + ctx: ConversionContext. + target: node target + source_ir (SourceIR): Source IR calling the function. + name: namespace for the op + input: Tensor or constant representing the dividend. + other: Tensor or constant representing the divisor. + + Returns: + A TensorRT tensor representing the result of the atan2 operation. + """ + pi_value = 3.141592653589793 + pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi") + + if isinstance(input, TRTTensor): + input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input") + if isinstance(other, TRTTensor): + other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other") + + input, other = broadcast(ctx.net, input, other, f"{name}_input", f"{name}_other") + + # Calculate x_zero, y_zero (whether inputs are zero) + x_zero = eq(ctx, target, source_ir, f"{name}_x_zero", input, 0) + y_zero = eq(ctx, target, source_ir, f"{name}_y_zero", other, 0) + + # Get sign of inputs + x_positive = gt(ctx, target, source_ir, f"{name}_x_positive", input, 0) + x_zero_positive = ge(ctx, target, source_ir, f"{name}_x_zero_positive", input, 0) + x_negative = lt(ctx, target, source_ir, f"{name}_x_negative", input, 0) + y_positive = gt(ctx, target, source_ir, f"{name}_y_positive", other, 0) + y_negative = lt(ctx, target, source_ir, f"{name}_y_negative", other, 0) + + # Calculate atan(x/y) + input_div_other = div( + ctx, target, source_ir, f"{name}_input_div_other", input, other + ) + atan_val = atan(ctx, target, source_ir, f"{name}_atan", input_div_other) + + # atan(x/y)+π if x≥0 and y<0, + atan_add_pi = add( + ctx, target, source_ir, f"{name}_atan_add_pi", atan_val, pi_tensor + ) + + # atan(x/y)-π if x<0 and y<0, + atan_sub_pi = sub( + ctx, target, source_ir, f"{name}_atan_sub_pi", atan_val, pi_tensor + ) + + # atan(x/y)+π if x≥0 and y<0, + atan_corrected = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_atan_corrected", + atan_add_pi, + atan_val, + logical_and( + ctx, + target, + source_ir, + f"{name}_x_zero_positive_and_y_negative", + x_zero_positive, + y_negative, + ), + ) + + # atan(x/y)-π if x<0 and y<0, + atan_corrected_2 = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_atan_corrected_2", + atan_sub_pi, + atan_corrected, + logical_and( + ctx, + target, + source_ir, + f"{name}_x_negative_and_y_negative", + x_negative, + y_negative, + ), + ) + + # atan(x/y) if y>0 + atan_output = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_atan_output", + atan_val, + atan_corrected_2, + y_positive, + ) + + # on x or y-axis + pi_over_2_tensor = get_trt_tensor( + ctx, + (pi_value / 2) * np.ones(input.shape, dtype=np.float32), + f"{name}_pi_over_2_tensor", + dtype=trt.float32, + ) + minus_pi_over_2_tensor = get_trt_tensor( + ctx, + (-pi_value / 2) * np.ones(input.shape, dtype=np.float32), + f"{name}_minus_pi_over_2_tensor", + dtype=trt.float32, + ) + zero_tensor = get_trt_tensor( + ctx, + np.zeros(input.shape, dtype=np.float32), + f"{name}_zero_tensor", + dtype=trt.float32, + ) + + # π/2 if x>0 and y=0, + pi_over_2_output = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_pi_over_2_output", + pi_over_2_tensor, + atan_output, + logical_and( + ctx, target, source_ir, f"{name}_x_zero_and_y_positive", x_positive, y_zero + ), + ) + + # -π/2 if x<0 and y=0, + minus_pi_over_2_output = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_minus_pi_over_2_output", + minus_pi_over_2_tensor, + pi_over_2_output, + logical_and( + ctx, target, source_ir, f"{name}_x_zero_and_y_negative", x_negative, y_zero + ), + ) + + # 0 if x=0 and y=0, + zero_output = impl.condition.select( + ctx, + target, + source_ir, + f"{name}_zero_output", + zero_tensor, + minus_pi_over_2_output, + logical_and( + ctx, target, source_ir, f"{name}_x_zero_and_y_zero", y_zero, x_zero + ), + ) + + return zero_output + + def clamp( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_atan2_aten.py b/tests/py/dynamo/conversion/test_atan2_aten.py new file mode 100644 index 0000000000..550ade2970 --- /dev/null +++ b/tests/py/dynamo/conversion/test_atan2_aten.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestAtan2Converter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_atan2_lhs_const(self, input_shape, dtype): + class atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + inputs = [ + torch.randn(input_shape, dtype=dtype), + torch.rand(1), + ] + + self.run_test( + atan2(), + inputs, + ) + + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_atan2_rhs_const(self, input_shape, dtype): + class atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + inputs = [ + torch.rand(1), + torch.randn(input_shape, dtype=dtype), + ] + + self.run_test( + atan2(), + inputs, + ) + + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_atan2_float(self, input_shape, dtype): + class atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + inputs = [ + torch.randn(input_shape, dtype=dtype), + torch.randn(input_shape, dtype=dtype), + ] + + self.run_test( + atan2(), + inputs, + ) + + @parameterized.expand( + [ + ((50,), torch.int, -5, 5), + ((1, 20), torch.int32, -5, 5), + ((2, 3, 4), torch.int, -5, 5), + ] + ) + def test_atan2_int(self, input_shape, dtype, low, high): + class atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + inputs = [ + torch.randint(low, high, input_shape, dtype=dtype), + torch.randint(low, high, input_shape, dtype=dtype), + ] + self.run_test( + atan2(), + inputs, + ) + + @parameterized.expand( + [ + (torch.float, 0.0, 0.0), + (torch.float, 0.0, torch.rand(1)), + (torch.float, torch.rand(1), 0.0), + (torch.int, 0, 0), + (torch.int, 0, torch.randint(-5, 5, (1,))), + (torch.int, torch.randint(1, 10, (1,)), 0), + ] + ) + def test_atan2_zero(self, dtype, x_val, y_val): + class Atan2(nn.Module): + def forward(self, lhs_val, rhs_val): + return torch.ops.aten.atan2.default(lhs_val, rhs_val) + + if isinstance(x_val, torch.Tensor): + x_val = x_val.item() + if isinstance(y_val, torch.Tensor): + y_val = y_val.item() + + inputs = [ + torch.tensor([x_val], dtype=dtype), + torch.tensor([y_val], dtype=dtype), + ] + + self.run_test( + Atan2(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From a9a627292fbf0709baba9de00d4a742d10a50956 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Fri, 12 Apr 2024 09:46:05 +0900 Subject: [PATCH 06/20] feat: support aten.index_select converter (#2710) --- .../dynamo/conversion/aten_ops_converters.py | 25 +++++++++++ .../conversion/test_index_select_aten.py | 41 +++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_index_select_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index d99ba6ef4b..b03d67d6c7 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2845,3 +2845,28 @@ def aten_ops_roll( args[1], args_bounds_check(args, 2, []), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.index_select.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + 2: (TRTTensor,), + } +) +def aten_ops_index_select( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.select.index_select( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + ) diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py new file mode 100644 index 0000000000..83eaedb944 --- /dev/null +++ b/tests/py/dynamo/conversion/test_index_select_aten.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestIndexSelectConverter(DispatchTestCase): + @parameterized.expand( + [ + ("1d_input", (10,), 0, (1,)), + ("2d_input_dim_0", (10, 3), 0, (0, 2)), + ("2d_input_dim_1", (5, 10), 1, (1, 2, 3)), + ("2d_input_dim_-2", (5, 10), -2, (1, 2, 3)), + ("3d_input_dim_0", (10, 5, 10), 0, (0, 5)), + ("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)), + ("3d_input_dim_-1", (10, 5, 10), -1, (3, 3, 4)), + ("3d_input_dim_-3", (10, 5, 10), -3, (5, 3, 4)), + ] + ) + def test_index_select(self, _, source_shape, dim, indices_val): + class TestIndexSelect(torch.nn.Module): + def forward(self, source_tensor, indices_tensor): + return torch.ops.aten.index_select.default( + source_tensor, dim, indices_tensor + ) + + input = [ + torch.randn(*source_shape, dtype=torch.float32), + torch.tensor([*indices_val], dtype=torch.int32), + ] + + self.run_test( + TestIndexSelect(), + input, + ) + + +if __name__ == "__main__": + run_tests() From f5b7b318e6e1fd4bd50498e34dce89b7ec66fe28 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Fri, 12 Apr 2024 09:47:14 +0900 Subject: [PATCH 07/20] feat: support aten.isnan converter (#2711) --- .../dynamo/conversion/aten_ops_converters.py | 17 ++++ .../dynamo/conversion/impl/unary/ops.py | 20 +++++ tests/py/dynamo/conversion/test_isnan_aten.py | 82 +++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_isnan_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index b03d67d6c7..91983e3e6b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1533,6 +1533,23 @@ def aten_ops_isinf( ) +@dynamo_tensorrt_converter(torch.ops.aten.isnan.default) +def aten_ops_isnan( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.isnan( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) def aten_ops_add( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 554640ea5a..4bc24051ee 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -508,3 +508,23 @@ def scalar_tensor( identity_layer = ctx.net.add_identity(tensor) set_layer_name(identity_layer, target, name, source_ir) return identity_layer.get_output(0) + + +def isnan( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + # False for NaN elements since NaN is not equal to anything, including itself. + equality_result = impl.elementwise.eq( + ctx, target, source_ir, f"{name}_eq_nan", input, input + ) + + # Invert equality_result to get a mask where NaN values are marked as True. + nan_values_mask = logical_not( + ctx, target, source_ir, f"{name}_logical_not", equality_result + ) + + return nan_values_mask diff --git a/tests/py/dynamo/conversion/test_isnan_aten.py b/tests/py/dynamo/conversion/test_isnan_aten.py new file mode 100644 index 0000000000..5651b0ca25 --- /dev/null +++ b/tests/py/dynamo/conversion/test_isnan_aten.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestIsNanConverter(DispatchTestCase): + @parameterized.expand( + [ + ( + torch.tensor( + [ + 1.23, + float("nan"), + -4.56, + float("inf"), + float("-inf"), + -100.0, + float("nan"), + 0.13, + -0.13, + 3.14159265, + ] + ), + ), + ] + ) + def test_isnan_float(self, data): + class isnan(nn.Module): + def forward(self, input): + return torch.ops.aten.isnan.default(input) + + inputs = [data] + self.run_test( + isnan(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + (torch.full((2, 2), float("nan"), dtype=torch.float32),), + (torch.full((3, 10, 5), float("nan"), dtype=torch.float32),), + (torch.randn((5, 10, 5), dtype=torch.float32),), + ] + ) + def test_isnan_dim(self, data): + class isnan(nn.Module): + def forward(self, input): + return torch.ops.aten.isnan.default(input) + + inputs = [data] + self.run_test( + isnan(), + inputs, + output_dtypes=[torch.bool], + ) + + @parameterized.expand( + [ + ((10,), torch.int, 0, 5), + ((1, 20), torch.int32, -10, 10), + ((2, 3, 4), torch.int, -5, 5), + ] + ) + def test_isnan_int(self, input_shape, dtype, low, high): + class isnan(nn.Module): + def forward(self, input): + return torch.ops.aten.isnan.default(input) + + inputs = [torch.randint(low, high, input_shape, dtype=dtype)] + self.run_test( + isnan(), + inputs, + output_dtypes=[torch.bool], + ) + + +if __name__ == "__main__": + run_tests() From 264905ddaf15f5b18986ac4c1390a376e19bd59b Mon Sep 17 00:00:00 2001 From: "Zewen (Evan) Li" Date: Fri, 12 Apr 2024 09:06:58 +0800 Subject: [PATCH 08/20] feat: support adaptive avg pool 2d and 3d dynamo converters (#2632) --- .../dynamo/conversion/aten_ops_converters.py | 33 ++- .../dynamo/conversion/impl/pool.py | 230 +++++++++++++++++- .../conversion/test_adaptive_avgpool_aten.py | 217 ++++++++++++++++- 3 files changed, 477 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 91983e3e6b..8c20b06223 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2242,7 +2242,12 @@ def aten_ops_avg_pool( @dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default) -def aten_ops_adaptive_avg_pool( +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_adaptive_avg_pool1d( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], @@ -2259,6 +2264,32 @@ def aten_ops_adaptive_avg_pool( ) +@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default) +@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default) +@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default) +@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_adaptive_avg_poolNd( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pool.adaptive_avg_poolNd( + ctx, + target, + source_ir=SourceIR.ATEN, + name=name, + input=args[0], + output_size=args[1], + ) + + def max_pool_param_validator(pool_node: Node) -> bool: dilation = args_bounds_check(pool_node.args, 4, 1) ceil_mode = args_bounds_check(pool_node.args, 5, False) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py index 8c16f59030..c21ccc1c59 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py @@ -6,7 +6,10 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple +from torch_tensorrt.dynamo.conversion.converter_utils import ( + extend_attr_to_tuple, + get_positive_dim, +) from torch_tensorrt.fx.converters.converter_utils import ( has_dynamic_shape, set_layer_name, @@ -169,3 +172,228 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int: output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1) return output + + +def adaptive_avg_poolNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + output_size: Sequence[int], +) -> TRTTensor: + input_shape = input.shape + input_rank = len(input_shape) + output_rank = len(output_size) + need_reshape_back = False + + if input_rank == output_rank + 1: # reshape to 4D/5D for TRT pooling + input = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape) + ) + need_reshape_back = True + input_shape = input.shape + input_rank = len(input_shape) + + extend_len = len(output_size) + output_size = list(output_size) + original_input = input + + # repeat_interleave the input if the dim of output is larger than input + insert_axises = [] + for axis in range(1, extend_len + 1): + axis = -axis + positive_axis = get_positive_dim( + axis, input_rank + ) # convert to positive axis, which is for calculating new shapes below + input_dim = input_shape[axis] + output_dim = output_size[axis] + diff = output_dim - input_dim + if diff > 0: # the dim of output is larger than input + times = output_dim // input_dim + remainder = output_dim % input_dim + if ( + diff == 2 and remainder == 2 + ): # case 1: output_dim - input_dim == 2 and is not an integral multiple + insert_axises.append(axis) + remainder -= 1 + output_size[axis] -= 1 + + if ( + remainder + 1 == input_dim + ): # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input + remainder = 0 + times += 1 + + flags = [] # record the axis that needs to be repeated + concat_list = [] + for j in range( + input_dim + ): # iterate the input dim to see which dim needs to be repeated or not + single_elem = impl.select.select( + ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j + ) + new_shape = list(single_elem.shape) + new_shape.insert(positive_axis, 1) + single_elem = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_{axis}_{j}", + single_elem, + new_shape, + ) + if remainder > 0 or j in flags: + concat_list.extend([single_elem] * (times + 1)) + remainder -= 2 + flags.append(input_dim - j - 1) + else: + concat_list.extend([single_elem] * times) + out = impl.cat.cat( + ctx, target, source_ir, f"{name}_cat_{axis}_{j}", concat_list, axis + ) + input = out + + stride = tuple( + input.shape[-extend_len + i] // output_size[i] for i in range(extend_len) + ) + kernel_size = tuple( + input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i] + for i in range(extend_len) + ) + + # Don't have to pool, directly return + if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size): + if need_reshape_back: # reshape back + input = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_back", + input, + (*input.shape[1:],), + ) + return input + + layer = ctx.net.add_pooling_nd( + input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size + ) + layer.stride_nd = stride + set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir) + + output = layer.get_output(0) + + # For case 1, we need to split the output and insert the mid of input + for axis in insert_axises: + positive_axis = get_positive_dim(axis, input_rank) + input_dim = input_shape[axis] + output_dim = output_size[axis] + if input_dim % 2 == 1: + prev_one = impl.select.select( + ctx, + target, + source_ir, + f"{name}_select_prev_one_{axis}", + output, + axis, + output_dim // 2 - 1, + ) + extend_shape = list(prev_one.shape) + extend_shape.insert(positive_axis, 1) + prev_one = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_reshape_extend_shape_{axis}", + prev_one, + extend_shape, + ) + prev_two = impl.select.select( + ctx, + target, + source_ir, + f"{name}_select_prev_two_{axis}", + output, + axis, + output_dim // 2 - 2, + ) + prev_two = impl.shuffle.reshape( + ctx, + target, + source_ir, + f"{name}_two_shape_reshape_{axis}", + prev_two, + extend_shape, + ) + prev_one_two_diff = impl.elementwise.sub( + ctx, + target, + source_ir, + f"{name}_prev_one_two_diff_{axis}", + prev_one, + prev_two, + ) + + mid = impl.elementwise.add( + ctx, + target, + source_ir, + f"{name}_mid_{axis}", + prev_one, + prev_one_two_diff, + ) + split_output = impl.split.split( + ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis + ) + split_output.insert(1, mid) + output = impl.cat.cat( + ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis + ) + else: + mid1 = impl.select.select( + ctx, + target, + source_ir, + f"{name}_select_{axis}", + original_input, + axis, + input_dim // 2 - 1, + ) + new_shape = list(mid1.shape) + new_shape.insert(positive_axis, 1) + mid1 = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape + ) + mid2 = impl.select.select( + ctx, + target, + source_ir, + f"{name}_select_{axis}", + original_input, + axis, + input_dim // 2, + ) + mid2 = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape + ) + split_output = impl.split.split( + ctx, + target, + source_ir, + f"{name}_split_{axis}", + output, + [output_dim // 2, 1, output_dim // 2], + axis, + ) + split_output[1] = mid1 + split_output.insert(2, mid2) + output = impl.cat.cat( + ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis + ) + + if need_reshape_back: # reshape back + output = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],) + ) + + return output diff --git a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py index 3d48409631..b8dc1e1968 100644 --- a/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py +++ b/tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py @@ -76,10 +76,225 @@ def forward(self, x): self.run_test( TestModule(), inputs, - # use_dynamo_tracer=True, enable_passes=True, ) + @parameterized.expand( + [ + # 3d input + ( + (1, 2, 3), + (1, 2), + ), + ( + (1, 2, 3), + (2, 3), + ), + ( + (1, 2, 8), + (4, 4), + ), + ( + (2, 3, 2), + (5, 3), + ), + ( + (2, 8, 16), + (4, 8), + ), + ( + (2, 8, 16), + (8, 8), + ), + # 4d input + ( + (1, 1, 4, 3), + (4, 8), + ), + ( + (3, 2, 3, 2), + (1, 5), + ), + ( + (4, 2, 2, 8), + (5, 2), + ), + ( + (3, 2, 3, 3), + (6, 4), + ), + ( + (1, 2, 3, 2), + (2, 2), + ), + ( + (2, 2, 32, 16), + (8, 8), + ), + ( + (2, 2, 32, 32), + (31, 16), + ), + ( + (1, 1, 64, 64), + (64, 16), + ), + ] + ) + def test_adaptive_avg_pool2d( + self, + input_shape, + output_size, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.adaptive_avg_pool2d.default(x, output_size) + + inputs = [torch.randn(input_shape)] + self.run_test( + TestModule(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ((1, 2),), + ] + ) + def test_adaptive_avg_pool2d_dynamic(self, output_size): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + out = torch.ops.aten.adaptive_avg_pool2d.default(x, output_size) + return out + + input_specs = [ + Input( + shape=(-1, 2, 3, 2), + dtype=torch.float32, + shape_ranges=[((1, 2, 3, 2), (3, 2, 3, 2), (10, 2, 3, 2))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + ) + + @parameterized.expand( + [ + # 4d input + ( + (1, 1, 4, 3), + (4, 8, 2), + ), + ( + (1, 2, 3, 1), + (1, 5, 2), + ), + ( + (1, 2, 3, 2), + (1, 5, 3), + ), + ( + (4, 2, 2, 8), + (8, 5, 2), + ), + ( + (3, 2, 3, 3), + (6, 4, 1), + ), + ( + (1, 2, 3, 2), + (2, 2, 2), + ), + ( + (2, 2, 32, 16), + (8, 8, 8), + ), + ( + (2, 2, 32, 32), + (31, 16, 64), + ), + ( + (1, 1, 64, 64), + (64, 16, 1), + ), + # 5d input + ( + (1, 1, 1, 4, 3), + (4, 8, 2), + ), + ( + (4, 3, 1, 2, 3), + (2, 4, 6), + ), + ( + (1, 4, 2, 2, 2), + (5, 2, 4), + ), + ( + (3, 2, 3, 3, 2), + (6, 4, 1), + ), + ( + (2, 2, 32, 16, 8), + (8, 8, 8), + ), + ( + (2, 2, 32, 32, 32), + (31, 16, 64), + ), + ( + (1, 1, 64, 64, 64), + (64, 16, 1), + ), + ] + ) + def test_adaptive_avgpool3d( + self, + input_shape, + output_size, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.adaptive_avg_pool3d.default(x, output_size) + + inputs = [torch.randn(input_shape)] + self.run_test( + TestModule(), + inputs, + enable_passes=True, + ) + + @parameterized.expand( + [ + ((1, 2, 3),), + ] + ) + def test_adaptive_avg_pool3d_dynamic(self, output_size): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + out = torch.ops.aten.adaptive_avg_pool3d.default(x, output_size) + return out + + input_specs = [ + Input( + shape=(-1, 2, 3, 1, 4), + dtype=torch.float32, + shape_ranges=[((1, 2, 3, 1, 4), (3, 2, 3, 1, 4), (10, 2, 3, 1, 4))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + ) + if __name__ == "__main__": run_tests() From fc29af410729f72c80ccecbe3b47969fa8f2900d Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Fri, 12 Apr 2024 10:38:30 +0900 Subject: [PATCH 09/20] feat: support aten.expm1 converter (#2714) --- .../dynamo/conversion/aten_ops_converters.py | 17 +++++ .../dynamo/conversion/impl/unary/ops.py | 26 +++++++ tests/py/dynamo/conversion/test_expm1_aten.py | 69 +++++++++++++++++++ 3 files changed, 112 insertions(+) create mode 100644 tests/py/dynamo/conversion/test_expm1_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 8c20b06223..c5e3a919eb 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1152,6 +1152,23 @@ def aten_ops_exp( ) +@dynamo_tensorrt_converter(torch.ops.aten.expm1.default) +def aten_ops_expm1( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.expm1( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.log.default) def aten_ops_log( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 4bc24051ee..9f2ad07612 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -44,6 +44,32 @@ def exp( ) +def expm1( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + """ + Computes e^x - 1 for each element of the input tensor. + + Args: + ctx (ConversionContext): TensorRT ConversionContext object. + target (Target): fx node target. + source_ir (SourceIR): Source IR calling the function + name (str): Name of the fx node with optional suffix. + input_val (TRTTensor): The input tensor. + + Returns: + TRTTensor: A TensorRT tensor represent the result of expm1 operator. + """ + # Compute e^x for each element of the input tensor + exp_result = exp(ctx, target, source_ir, f"{name}_exp", input_val) + + return impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", exp_result, 1) + + def log( ctx: ConversionContext, target: Target, diff --git a/tests/py/dynamo/conversion/test_expm1_aten.py b/tests/py/dynamo/conversion/test_expm1_aten.py new file mode 100644 index 0000000000..e695a27475 --- /dev/null +++ b/tests/py/dynamo/conversion/test_expm1_aten.py @@ -0,0 +1,69 @@ +from math import exp + +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestExpConverter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), torch.float), + ((1, 20), torch.float), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_expm1_float(self, input_shape, dtype): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + inputs = [torch.randn(input_shape, dtype=dtype)] + self.run_test( + expm1(), + inputs, + ) + + @parameterized.expand( + [ + (torch.full((1, 20), exp(1), dtype=torch.float),), + (torch.full((2, 3, 4), exp(2), dtype=torch.float),), + (torch.full((2, 3, 4, 5), exp(3), dtype=torch.float),), + ] + ) + def test_expm1_exp_const_float(self, data): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + inputs = [data] + self.run_test( + expm1(), + inputs, + ) + + @parameterized.expand( + [ + ((10,), torch.int, 0, 5), + ((1, 20), torch.int32, -10, 10), + ((2, 3, 4), torch.int, -5, 5), + ] + ) + def test_exp_int(self, input_shape, dtype, low, high): + class expm1(nn.Module): + def forward(self, input): + return torch.ops.aten.expm1.default(input) + + inputs = [torch.randint(low, high, input_shape, dtype=dtype)] + self.run_test( + expm1(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From 4abec39c07391c5328231f2f7c05487a287403bb Mon Sep 17 00:00:00 2001 From: George S <113141689+gs-olive@users.noreply.github.com> Date: Thu, 11 Apr 2024 19:18:12 -0700 Subject: [PATCH 10/20] fix: Add dependencies to Docker container for `apt` versioning TRT (#2746) --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 60b213b110..16b92bbd17 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -47,7 +47,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/ RUN add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" RUN apt-get update -RUN apt-get install -y libnvinfer8=${TENSORRT_VERSION}.* libnvinfer-plugin8=${TENSORRT_VERSION}.* libnvinfer-dev=${TENSORRT_VERSION}.* libnvinfer-plugin-dev=${TENSORRT_VERSION}.* libnvonnxparsers8=${TENSORRT_VERSION}.* libnvonnxparsers-dev=${TENSORRT_VERSION}.* libnvparsers8=${TENSORRT_VERSION}.* libnvparsers-dev=${TENSORRT_VERSION}.* +RUN apt-get install -y libnvinfer8=${TENSORRT_VERSION}.* libnvinfer-plugin8=${TENSORRT_VERSION}.* libnvinfer-dev=${TENSORRT_VERSION}.* libnvinfer-plugin-dev=${TENSORRT_VERSION}.* libnvonnxparsers8=${TENSORRT_VERSION}.* libnvonnxparsers-dev=${TENSORRT_VERSION}.* libnvparsers8=${TENSORRT_VERSION}.* libnvparsers-dev=${TENSORRT_VERSION}.* libnvinfer-headers-dev=${TENSORRT_VERSION}.* libnvinfer-headers-plugin-dev=${TENSORRT_VERSION}.* # Setup Bazel via Bazelisk RUN wget -q https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 -O /usr/bin/bazel &&\ From 25a3b28128befa140ad673f7aab78c9c56bdba78 Mon Sep 17 00:00:00 2001 From: "Zewen (Evan) Li" Date: Wed, 17 Apr 2024 04:11:56 +0800 Subject: [PATCH 11/20] fix: param bug in `test_binary_ops_aten` (#2733) --- tests/py/dynamo/conversion/test_binary_ops_aten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/py/dynamo/conversion/test_binary_ops_aten.py b/tests/py/dynamo/conversion/test_binary_ops_aten.py index 331fab591d..ebe727b716 100644 --- a/tests/py/dynamo/conversion/test_binary_ops_aten.py +++ b/tests/py/dynamo/conversion/test_binary_ops_aten.py @@ -116,7 +116,7 @@ def forward(self, x): inputs = [torch.randn(2, 2)] self.run_test(m, inputs) - @parameterized.expand([((lambda x, y: torch.ops.aten.div.Tensor(x, y)))]) + @parameterized.expand([(lambda x, y: torch.ops.aten.div.Tensor(x, y),)]) def test_elementwise_op_div_with_two_ints(self, orig_op: Callable): class TestModule(nn.Module): def __init__(self, orig_op): @@ -130,7 +130,7 @@ def forward(self, x): inputs = [torch.randint(1, 10, (5,), dtype=torch.int32)] self.run_test(m, inputs) - @parameterized.expand([(lambda x, y: torch.ops.aten.div.Tensor(x, y))]) + @parameterized.expand([(lambda x, y: torch.ops.aten.div.Tensor(x, y),)]) def test_elementwise_op_div_with_one_int_one_constant(self, orig_op: Callable): class TestModule(nn.Module): def __init__(self, orig_op): From 822e63c1d84619b270352fecf69bd244c4288528 Mon Sep 17 00:00:00 2001 From: Apurba Bose <44209735+apbose@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:29:03 -0700 Subject: [PATCH 12/20] aten::empty_like (#2654) --- py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index de791851db..98c25a1f54 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -37,6 +37,7 @@ aten.elu_backward, aten._embedding_bag, aten.embedding_dense_backward, + aten.empty_like, aten._euclidean_dist.default, aten.expand_as, aten.eye, From dee74c4da643357abb4d2aee8830fa76a1422b76 Mon Sep 17 00:00:00 2001 From: Apurba Bose <44209735+apbose@users.noreply.github.com> Date: Tue, 16 Apr 2024 22:24:08 -0700 Subject: [PATCH 13/20] empty_permute decomposition (#2698) --- .../dynamo/lowering/_decompositions.py | 12 ++++ .../py/dynamo/lowering/test_decompositions.py | 65 +++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 981c80f9fa..9ba7ec964b 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -162,6 +162,18 @@ def var_decomposition( return variance +@register_torch_trt_decomposition( + torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor: + empty_size = args[0] + empty_permute = args[1] + perm = [0] * len(empty_size) + for permute_index, permute_element in enumerate(empty_permute): + perm[permute_element] = permute_index + return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm) + + def get_decompositions( enable_experimental_decompositions: bool = False, ) -> Dict[OpOverload, Callable[[Any], Any]]: diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 84e8d11585..457e9e2e81 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -420,6 +420,71 @@ def forward(self, x): f"MaxPool3d TRT outputs don't match with the original model.", ) + def test_lowering_empty_like_module(self): + class emptyLike(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x): + c = torch.ops.aten.add(x, x) + y = torch.ops.aten.empty_like.default(c) + d = y + c + return d + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = {torch.ops.aten.add.Tensor} + unexpected_ops = { + torch.ops.aten.empty_like.default, + torch.ops.aten.empty_permuted.default, + } + + inputs = [torch.zeros(3, 2).cuda()] + + fx_graph = torch.fx.symbolic_trace(emptyLike()) + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + truncate_long_and_double=True, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + f"Select_scatter TRT outputs don't match with the original model.", + ) + if __name__ == "__main__": run_tests() From 77c4b964c4853ba9f0b1433ac58c9b4b29a5af0c Mon Sep 17 00:00:00 2001 From: Apurba Bose <44209735+apbose@users.noreply.github.com> Date: Tue, 16 Apr 2024 22:43:17 -0700 Subject: [PATCH 14/20] Removing grid lowering (#2686) --- .../dynamo/conversion/aten_ops_converters.py | 2 + .../dynamo/lowering/_decomposition_groups.py | 1 - tests/py/dynamo/conversion/test_grid_aten.py | 171 +++++++++++------- 3 files changed, 108 insertions(+), 66 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index c5e3a919eb..32b9691f1f 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -332,6 +332,8 @@ def aten_ops_fmod( @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.default) +@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.default) @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index 98c25a1f54..84d9af5939 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -47,7 +47,6 @@ aten.gelu, aten.gelu_backward, aten.glu_backward, - aten.grid_sampler_2d, aten.hardshrink, aten.hardshrink_backward, aten.hardsigmoid, diff --git a/tests/py/dynamo/conversion/test_grid_aten.py b/tests/py/dynamo/conversion/test_grid_aten.py index 32480110f3..e3b5783b19 100644 --- a/tests/py/dynamo/conversion/test_grid_aten.py +++ b/tests/py/dynamo/conversion/test_grid_aten.py @@ -6,112 +6,74 @@ from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input +grid_sampler_aten_ops = { + "torch.ops.aten.grid_sampler": torch.ops.aten.grid_sampler, + "torch.ops.aten.grid_sampler_2d": torch.ops.aten.grid_sampler_2d, + "torch.ops.aten.grid_sampler.default": torch.ops.aten.grid_sampler.default, + "torch.ops.aten.grid_sampler_2d.default": torch.ops.aten.grid_sampler_2d.default, +} + grid_sampler_ops = [ ( "input_grid_interpolation_nearest_sample_fill", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 0, 0, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_nearest_sample_clamp", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 0, 1, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_nearest_sample_reflect", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 0, 2, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_linear_sample_fill", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 1, 0, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_linear_sample_clamp", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 1, 1, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_linear_sample_reflect", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 1, 2, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_cubic_sample_fill", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 2, 0, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_cubic_sample_clamp", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 2, 1, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), ( "input_grid_interpolation_cubic_sample_reflect", - (lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_nearest_sample_fill_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_nearest_sample_clamp_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_nearest_sample_reflect_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_linear_sample_fill_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_linear_sample_clamp_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_linear_sample_reflect_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_cubic_sample_fill_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_cubic_sample_clamp_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)), - [1, 1, 5, 5], - [1, 5, 2, 2], - ), - ( - "input_grid_interpolation_cubic_sample_reflect_2d", - (lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)), + "torch.ops.aten.grid_sampler", + (lambda x, grid, op: op(x, grid, 2, 2, True)), [1, 1, 5, 5], [1, 5, 2, 2], ), @@ -126,11 +88,90 @@ class TestGridConverter(DispatchTestCase): grid_sampler_op[1], grid_sampler_op[2], grid_sampler_op[3], + grid_sampler_op[4], + ) + for grid_sampler_op in grid_sampler_ops + ] + ) + def test_grid(self, _, op_name, op, input_shape, dim_shape): + class TestModule(nn.Module): + def __init__(self, grid_sampler_op): + super().__init__() + self.grid_sampler_op = grid_sampler_op + + def forward(self, x): + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) + return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name]) + + inputs = [torch.randn(input_shape, dtype=torch.float32)] + grid_model = TestModule(op) + self.run_test(grid_model, inputs) + + @parameterized.expand( + [ + ( + grid_sampler_op[0], + grid_sampler_op[1] + "_2d", + grid_sampler_op[2], + grid_sampler_op[3], + grid_sampler_op[4], + ) + for grid_sampler_op in grid_sampler_ops + ] + ) + def test_grid_2d(self, _, op_name, op, input_shape, dim_shape): + class TestModule(nn.Module): + def __init__(self, grid_sampler_op): + super().__init__() + self.grid_sampler_op = grid_sampler_op + + def forward(self, x): + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) + return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name]) + + inputs = [torch.randn(input_shape, dtype=torch.float32)] + grid_model = TestModule(op) + self.run_test(grid_model, inputs) + + @parameterized.expand( + [ + ( + grid_sampler_op[0], + grid_sampler_op[1] + ".default", + grid_sampler_op[2], + grid_sampler_op[3], + grid_sampler_op[4], + ) + for grid_sampler_op in grid_sampler_ops + ] + ) + def test_grid_default(self, _, op_name, op, input_shape, dim_shape): + class TestModule(nn.Module): + def __init__(self, grid_sampler_op): + super().__init__() + self.grid_sampler_op = grid_sampler_op + + def forward(self, x): + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) + return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name]) + + inputs = [torch.randn(input_shape, dtype=torch.float32)] + grid_model = TestModule(op) + self.run_test(grid_model, inputs) + + @parameterized.expand( + [ + ( + grid_sampler_op[0], + grid_sampler_op[1] + "_2d.default", + grid_sampler_op[2], + grid_sampler_op[3], + grid_sampler_op[4], ) for grid_sampler_op in grid_sampler_ops ] ) - def test_grid(self, _, op, input_shape, dim_shape): + def test_grid_2d_default(self, _, op_name, op, input_shape, dim_shape): class TestModule(nn.Module): def __init__(self, grid_sampler_op): super().__init__() @@ -138,7 +179,7 @@ def __init__(self, grid_sampler_op): def forward(self, x): grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) - return self.grid_sampler_op(x, grid) + return self.grid_sampler_op(x, grid, grid_sampler_aten_ops[op_name]) inputs = [torch.randn(input_shape, dtype=torch.float32)] grid_model = TestModule(op) From 2f8937f6bce72289605016b04dd2bfb86377c5fe Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 18:53:40 -0600 Subject: [PATCH 15/20] chore(deps): bump transformers from 4.33.2 to 4.36.0 in /tools/perf (#2555) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tools/perf/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/perf/requirements.txt b/tools/perf/requirements.txt index 159e6f5eab..08a732b658 100644 --- a/tools/perf/requirements.txt +++ b/tools/perf/requirements.txt @@ -2,7 +2,7 @@ numpy argparse pyyaml onnx -transformers==4.33.2 +transformers==4.36.0 diffusers==0.21.4 pandas==2.0.1 timm==0.9.8 From 6ea06d91bea0c4049cf501c4a36a3516fea63ab2 Mon Sep 17 00:00:00 2001 From: HolyWu Date: Fri, 19 Apr 2024 09:06:27 +0800 Subject: [PATCH 16/20] Fix upsample converter not properly registered (#2683) --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 32b9691f1f..c566d9de0a 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2639,6 +2639,7 @@ def aten_ops_pad( ) +@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.default) @dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec) def upsample_nearest2d( ctx: ConversionContext, @@ -2660,6 +2661,7 @@ def upsample_nearest2d( ) +@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.default) @dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec) def upsample_bilinear2d( ctx: ConversionContext, From 79f7f380317769430dbb4e68d1b26c005c232ae8 Mon Sep 17 00:00:00 2001 From: Michael Feliz <104801882+mfeliz-cruise@users.noreply.github.com> Date: Thu, 18 Apr 2024 18:08:07 -0700 Subject: [PATCH 17/20] feat: TS Add converter support for aten::grid_sampler (#2717) --- .../converters/impl/interpolate.cpp | 31 ++++++ .../converters/test_interpolate.cpp | 96 +++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/core/conversion/converters/impl/interpolate.cpp b/core/conversion/converters/impl/interpolate.cpp index fad2ca5121..b9a5f631b0 100644 --- a/core/conversion/converters/impl/interpolate.cpp +++ b/core/conversion/converters/impl/interpolate.cpp @@ -520,6 +520,37 @@ auto interpolate_registrations TORCHTRT_UNUSED = resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners); } + return true; + }}) + .pattern( + {"aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensorOrFreeze(ctx); + auto grid = args[1].ITensorOrFreeze(ctx); + auto interpolation_mode = args[2].unwrapToInt(); + auto padding_mode = args[3].unwrapToInt(); + auto align_corners = args[4].unwrapToBool(); + + static const auto sample_map = std::map{ + {0, nvinfer1::SampleMode::kFILL}, + {1, nvinfer1::SampleMode::kCLAMP}, + {2, nvinfer1::SampleMode::kREFLECT}}; + + static const auto interpolation_map = std::map{ + {0, nvinfer1::InterpolationMode::kLINEAR}, + {1, nvinfer1::InterpolationMode::kNEAREST}, + {2, nvinfer1::InterpolationMode::kCUBIC}}; + + auto grid_sample_layer = ctx->net->addGridSample(*in, *grid); + TORCHTRT_CHECK( + grid_sample_layer, "Unable to create grid_sample layer from node: " << util::node_info(n)); + + grid_sample_layer->setAlignCorners(align_corners); + grid_sample_layer->setSampleMode(sample_map.at(padding_mode)); + grid_sample_layer->setInterpolationMode(interpolation_map.at(interpolation_mode)); + + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], grid_sample_layer->getOutput(0)); + LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); return true; }}); diff --git a/tests/core/conversion/converters/test_interpolate.cpp b/tests/core/conversion/converters/test_interpolate.cpp index 22931bf9ec..c3b92c3be1 100644 --- a/tests/core/conversion/converters/test_interpolate.cpp +++ b/tests/core/conversion/converters/test_interpolate.cpp @@ -377,3 +377,99 @@ ATEN_INTERPOLATE_STATIC_ONLY_TEST( %7 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %6) return (%7))IR", std::vector({10, 2, 2, 2, 2})); + +TEST(Converters, GridSampleConvertsCorrectly) { + const auto graph = R"IR( + graph(%input : Tensor, %grid : Tensor): + %5 : int = prim::Constant[value=2]() + %6 : int = prim::Constant[value=2]() + %7 : bool = prim::Constant[value=1]() + %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7) + return (%8))IR"; + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA); + auto d = at::linspace(-1, 1, 8); + auto mesh = at::meshgrid({d, d}); + auto mesh_x = mesh[0]; + auto mesh_y = mesh[1]; + auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA); + + auto trt_input = input.clone(); + auto trt_grid = grid.clone(); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid}); + + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid}); + + for (size_t i = 0; i < jit_results.size(); i++) { + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6)); + } +} + +TEST(Converters, GridSampleOptions1ConvertsCorrectly) { + const auto graph = R"IR( + graph(%input : Tensor, %grid : Tensor): + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %7 : bool = prim::Constant[value=0]() + %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7) + return (%8))IR"; + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA); + auto d = at::linspace(-1, 1, 8); + auto mesh = at::meshgrid({d, d}); + auto mesh_x = mesh[0]; + auto mesh_y = mesh[1]; + auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA); + + auto trt_input = input.clone(); + auto trt_grid = grid.clone(); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid}); + + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid}); + + for (size_t i = 0; i < jit_results.size(); i++) { + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6)); + } +} + +TEST(Converters, GridSampleOptions2ConvertsCorrectly) { + const auto graph = R"IR( + graph(%input : Tensor, %grid : Tensor): + %5 : int = prim::Constant[value=0]() + %6 : int = prim::Constant[value=0]() + %7 : bool = prim::Constant[value=0]() + %8 : Tensor = aten::grid_sampler(%input, %grid, %5, %6, %7) + return (%8))IR"; + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto input = at::arange(16).view({1, 1, 4, 4}).to(at::kFloat).to(at::kCUDA); + auto d = at::linspace(-1, 1, 8); + auto mesh = at::meshgrid({d, d}); + auto mesh_x = mesh[0]; + auto mesh_y = mesh[1]; + auto grid = at::stack({mesh_x, mesh_y}, 2).unsqueeze(0).to(at::kCUDA); + + auto trt_input = input.clone(); + auto trt_grid = grid.clone(); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {input, grid}); + + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_input, trt_grid}); + + for (size_t i = 0; i < jit_results.size(); i++) { + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6)); + } +} \ No newline at end of file From 164b352f7c8fb8085d8772d48c88345722261cb1 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 26 Apr 2024 09:43:45 -0700 Subject: [PATCH 18/20] chore: fix is_nan_test --- tests/py/dynamo/conversion/test_isnan_aten.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/py/dynamo/conversion/test_isnan_aten.py b/tests/py/dynamo/conversion/test_isnan_aten.py index 5651b0ca25..2efb2ed4c0 100644 --- a/tests/py/dynamo/conversion/test_isnan_aten.py +++ b/tests/py/dynamo/conversion/test_isnan_aten.py @@ -36,7 +36,6 @@ def forward(self, input): self.run_test( isnan(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -55,7 +54,6 @@ def forward(self, input): self.run_test( isnan(), inputs, - output_dtypes=[torch.bool], ) @parameterized.expand( @@ -74,7 +72,6 @@ def forward(self, input): self.run_test( isnan(), inputs, - output_dtypes=[torch.bool], ) From bec91fb91794decf55e30584567f98aaa391b974 Mon Sep 17 00:00:00 2001 From: Aakash Apoorv Date: Fri, 26 Apr 2024 02:07:43 +0200 Subject: [PATCH 19/20] Fix minor grammatical corrections (#2779) --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 875b640304..fcac84be27 100644 --- a/README.md +++ b/README.md @@ -5,13 +5,13 @@ > Ahead of Time (AOT) compiling for PyTorch JIT and FX -Torch-TensorRT is a compiler for PyTorch/TorchScript/FX, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, Torch-TensorRT is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript or FX program into an module targeting a TensorRT engine. Torch-TensorRT operates as a PyTorch extention and compiles modules that integrate into the JIT runtime seamlessly. After compilation using the optimized graph should feel no different than running a TorchScript module. You also have access to TensorRT's suite of configurations at compile time, so you are able to specify operating precision (FP32/FP16/INT8) and other settings for your module. +Torch-TensorRT is a compiler for PyTorch/TorchScript/FX, targeting NVIDIA GPUs via NVIDIA's TensorRT Deep Learning Optimizer and Runtime. Unlike PyTorch's Just-In-Time (JIT) compiler, Torch-TensorRT is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript or FX program into an module targeting a TensorRT engine. Torch-TensorRT operates as a PyTorch extension and compiles modules that integrate into the JIT runtime seamlessly. After compilation using the optimized graph should feel no different than running a TorchScript module. You also have access to TensorRT's suite of configurations at compile time, so you are able to specify operating precision (FP32/FP16/INT8) and other settings for your module. Resources: - [Documentation](https://nvidia.github.io/Torch-TensorRT/) - [FX path Documentation](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst) - [Torch-TensorRT Explained in 2 minutes!](https://www.youtube.com/watch?v=TU5BMU6iYZ0&ab_channel=NVIDIADeveloper) -- [Comprehensive Discusion (GTC Event)](https://www.nvidia.com/en-us/on-demand/session/gtcfall21-a31107/) +- [Comprehensive Discussion (GTC Event)](https://www.nvidia.com/en-us/on-demand/session/gtcfall21-a31107/) - [Pre-built Docker Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). To use this container, make an NGC account and sign in to NVIDIA's registry with an API key. Refer to [this guide](https://docs.nvidia.com/ngc/ngc-catalog-user-guide/index.html#registering-activating-ngc-account) for the same. ## NVIDIA NGC Container @@ -44,7 +44,7 @@ If you would like to build outside a docker container, please follow the section #include "torch_tensorrt/torch_tensorrt.h" ... -// Set input datatypes. Allowerd options torch::{kFloat, kHalf, kChar, kInt32, kBool} +// Set input datatypes. Allowed options torch::{kFloat, kHalf, kChar, kInt32, kBool} // Size of input_dtypes should match number of inputs to the network. // If input_dtypes is not set, default precision follows traditional PyT / TRT rules auto input = torch_tensorrt::Input(dims, torch::kHalf); @@ -306,7 +306,7 @@ Supported Python versions: ### In Torch-TensorRT? -Thanks for wanting to contribute! There are two main ways to handle supporting a new op. Either you can write a converter for the op from scratch and register it in the NodeConverterRegistry or if you can map the op to a set of ops that already have converters you can write a graph rewrite pass which will replace your new op with an equivalent subgraph of supported ops. Its preferred to use graph rewriting because then we do not need to maintain a large library of op converters. Also do look at the various op support trackers in the [issues](https://github.com/pytorch/TensorRT/issues) for information on the support status of various operators. +Thanks for wanting to contribute! There are two main ways to handle supporting a new op. Either you can write a converter for the op from scratch and register it in the NodeConverterRegistry or if you can map the op to a set of ops that already have converters you can write a graph rewrite pass which will replace your new op with an equivalent subgraph of supported ops. It's preferred to use graph rewriting because then we do not need to maintain a large library of op converters. Also do look at the various op support trackers in the [issues](https://github.com/pytorch/TensorRT/issues) for information on the support status of various operators. ### In my application? From 08f1636e6f72c1eb21129646c4d03e43fbe7c931 Mon Sep 17 00:00:00 2001 From: "Zewen (Evan) Li" Date: Wed, 24 Apr 2024 17:13:28 -0700 Subject: [PATCH 20/20] fix: convert_module_to_trt_engine (#2728) --- docsrc/py_api/dynamo.rst | 2 + py/torch_tensorrt/_compile.py | 16 ++++-- py/torch_tensorrt/dynamo/_compiler.py | 50 ++++++++----------- ...y => test_convert_module_to_trt_engine.py} | 10 ++-- 4 files changed, 41 insertions(+), 37 deletions(-) rename tests/py/dynamo/runtime/{test_convert_method_to_trt_engine.py => test_convert_module_to_trt_engine.py} (81%) diff --git a/docsrc/py_api/dynamo.rst b/docsrc/py_api/dynamo.rst index fce5372d0e..6b4a527663 100644 --- a/docsrc/py_api/dynamo.rst +++ b/docsrc/py_api/dynamo.rst @@ -22,6 +22,8 @@ Functions .. autofunction:: export +.. autofunction:: convert_module_to_trt_engine + Classes diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 01692006a6..13bb2a585d 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -1,5 +1,6 @@ from __future__ import annotations +import collections.abc import logging from enum import Enum from typing import Any, Callable, List, Optional, Sequence, Set @@ -237,8 +238,6 @@ def compile( return compiled_fx_module elif target_ir == _IRType.dynamo: # Prepare torch and torchtrt inputs - import collections.abc - from torch_tensorrt.dynamo.utils import prepare_inputs if not isinstance(input_list, collections.abc.Sequence): @@ -342,10 +341,19 @@ def convert_method_to_trt_engine( "convert_method_to_trt_engine call is not supported for ir=fx" ) elif target_ir == _IRType.dynamo: + # Prepare torch and torchtrt inputs + from torch_tensorrt.dynamo.utils import prepare_inputs + + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + + # Export the module + torchtrt_inputs = prepare_inputs(inputs) + exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs) + return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return] - module, + exp_program, inputs=inputs, - method_name=method_name, enabled_precisions=enabled_precisions_set, **kwargs, ) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 09543a5d64..66e3ab1d61 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -422,8 +422,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: def convert_module_to_trt_engine( - module: torch.fx.GraphModule, - method_name: str = "forward", + exported_program: ExportedProgram, inputs: Optional[Sequence[Input | torch.Tensor]] = None, enabled_precisions: ( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] @@ -453,15 +452,15 @@ def convert_module_to_trt_engine( calibrator: object = None, allow_shape_tensors: bool = False, ) -> bytes: - """Convert a GraphModule module method to a serialized TensorRT engine + """Convert an ExportedProgram to a serialized TensorRT engine - Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings + Converts an ExportedProgram to a serialized TensorRT engine given a dictionary of conversion settings Arguments: - module (torch.fx.GraphModule): Source module + exported_program (torch.export.ExportedProgram): Source module Keyword Args: - inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using + inputs (Optional[Sequence[torch_tensorrt.Input | torch.Tensor]]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. :: @@ -476,30 +475,11 @@ def convert_module_to_trt_engine( ), # Dynamic input shape for input #2 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] - - method_name (str): Name of method to convert - input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using - torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** :: - - input_signature=([ - torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1 - torch_tensorrt.Input( - min_shape=(1, 224, 224, 3), - opt_shape=(1, 512, 512, 3), - max_shape=(1, 1024, 1024, 3), - dtype=torch.int32 - format=torch.channel_last - ), # Dynamic input shape for input #2 - ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3 - - device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on :: - - device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) - + enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use debug (bool): Whether to print out verbose debugging information workspace_size (int): Workspace TRT is allowed to use for the module (0 is default) min_block_size (int): Minimum number of operators per TRT-Engine Block - torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage + torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False) max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine version_compatible (bool): Provide version forward-compatibility for engine plan files @@ -566,13 +546,25 @@ def convert_module_to_trt_engine( "dla_global_dram_size": dla_global_dram_size, } + # Decompose the exported program + exported_program = exported_program.run_decompositions( + get_decompositions(enable_experimental_decompositions) + ) + gm = exported_program.module() + logger.debug("Input graph: " + str(gm.graph)) + + # Apply lowering on the graph module + torch_inputs = get_torch_inputs(input_list, device) + gm = apply_lowering_passes(gm, torch_inputs) + logger.debug("Lowered Input graph: " + str(gm.graph)) + settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) try: - interpreter_result = interpret_module_to_result(module, input_list, settings) + interpreter_result = interpret_module_to_result(gm, input_list, settings) except UnsupportedOperatorException: logger.error( - f"Conversion of module {module} not currently fully supported or convertible!", + f"Conversion of module {gm} not currently fully supported or convertible!", exc_info=True, ) except Exception as e: diff --git a/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py b/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py similarity index 81% rename from tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py rename to tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py index b10cae23fa..00b5dd8b31 100644 --- a/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py +++ b/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py @@ -7,7 +7,7 @@ from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity -class TestConvertMethodToTrtEngine(unittest.TestCase): +class TestConvertModuleToTrtEngine(unittest.TestCase): def test_convert_module(self): class Test(torch.nn.Module): def forward(self, a, b): @@ -18,11 +18,11 @@ def forward(self, a, b): # Create a model model = Test() - symbolic_traced_gm = torch.fx.symbolic_trace(model) + exp_program = torch.export.export(model, (input_data_0, input_data_1)) # Convert to TensorRT engine trt_engine_str = torch_tensorrt.dynamo.convert_module_to_trt_engine( - symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1] + exp_program, inputs=(input_data_0, input_data_1) ) # Deserialize the TensorRT engine @@ -30,7 +30,9 @@ def forward(self, a, b): engine = runtime.deserialize_cuda_engine(trt_engine_str) # Inference on TRT Engine - py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"]) + py_trt_module = PythonTorchTensorRTModule( + engine, ["arg0_1", "arg1_1"], ["output0"] + ) trt_output = py_trt_module(input_data_0, input_data_1).cpu() # Inference on PyTorch model