diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 2467fdd6ae..246b7e3cb7 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -89,7 +89,9 @@ def _pretraced_backend( gm = apply_lowering_passes(gm, sample_inputs) - torchtrt_inputs = prepare_inputs(sample_inputs) + torchtrt_inputs = prepare_inputs( + sample_inputs, disable_memory_format_check=True + ) trt_compiled = compile_module( gm, torchtrt_inputs, diff --git a/py/torch_tensorrt/dynamo/conversion/conversion.py b/py/torch_tensorrt/dynamo/conversion/conversion.py index 77a66c7c6d..59ca3e3143 100644 --- a/py/torch_tensorrt/dynamo/conversion/conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/conversion.py @@ -35,7 +35,16 @@ def convert_module( if not isinstance(module_outputs, (list, tuple)): module_outputs = [module_outputs] - output_dtypes = [output.dtype for output in module_outputs] + # Int64 outputs can sometimes be generated from within other operators + # such as aten.sum - such outputs can be truncated + output_dtypes = [] + for output in module_outputs: + if settings.truncate_long_and_double and output.dtype == torch.float64: + output_dtypes.append(torch.float32) + elif settings.truncate_long_and_double and output.dtype == torch.int64: + output_dtypes.append(torch.int32) + else: + output_dtypes.append(output.dtype) interpreter = TRTInterpreter( module, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py index 692a52b0df..2fcd57a7f6 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py @@ -49,10 +49,8 @@ def sum( dim: Optional[Union[int, Sequence[int]]], keepdim: bool, ) -> TRTTensor: - if (isinstance(input_val, TRTTensor)) and ( - input_val.dtype == trt.int8 or input_val.dtype == trt.int32 - ): - input_val = cast_trt_tensor(ctx, input_val, trt.float32, name) + if (isinstance(input_val, TRTTensor)) and (input_val.dtype == trt.bool): + input_val = cast_trt_tensor(ctx, input_val, trt.int32, name) if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0): dim = tuple(range(len(input_val.shape))) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 97ffdb728f..8a77508014 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -31,30 +31,23 @@ def slice_op( # TODO: This should be slice not whatever is in base "of the TensorRT region!" ) - ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0) - dim = get_positive_dim(dim, ranks) - dynamic_shape = has_dynamic_shape(input.shape) - if ctx.net.has_implicit_batch_dimension: - if dim == 0: - raise RuntimeError( - f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" - ) - dim = dim - 1 - else: - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" - start_int = start - stop_int = stop - if stop_int == 2**63 - 1: - stop_int = input.shape[dim] - step_int = step + dim = get_positive_dim(dim, len(input.shape)) + start = get_positive_dim(start, input.shape[dim]) + stop = get_positive_dim(stop, input.shape[dim]) + + if has_dynamic_shape(input.shape): + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!" + + if stop == 2**63 - 1: + stop = input.shape[dim] + start_slice = [0] * len(input.shape) - start_slice[dim] = start_int - stride_slice = [1] * len(start_slice) - stride_slice[dim] = step_int + start_slice[dim] = start + stride_slice = [1] * len(input.shape) + stride_slice[dim] = step output_shape = list(input.shape) - output_shape[dim] = math.ceil((stop_int - start_int) / step_int) + output_shape[dim] = math.ceil((stop - start) / step) return slice( ctx, target, source_ir, name, input, start_slice, output_shape, stride_slice diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 97046ba421..fa84a5d0c4 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -5,12 +5,12 @@ from typing import Any, Callable, Dict, Optional, Sequence, Union import torch -import torch_tensorrt from torch_tensorrt._Device import Device from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo._defaults import PRECISION +import torch_tensorrt from packaging import version logger = logging.getLogger(__name__) @@ -104,17 +104,22 @@ def set_log_level(parent_logger: Any, level: Any) -> None: def prepare_inputs( inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], + disable_memory_format_check: bool = False, ) -> Any: if isinstance(inputs, Input): return inputs elif isinstance(inputs, torch.Tensor): - return Input.from_tensor(inputs) + return Input.from_tensor( + inputs, disable_memory_format_check=disable_memory_format_check + ) elif isinstance(inputs, list): torchtrt_input_list = [] for input_obj in inputs: - torchtrt_input = prepare_inputs(input_obj) + torchtrt_input = prepare_inputs( + input_obj, disable_memory_format_check=disable_memory_format_check + ) torchtrt_input_list.append(torchtrt_input) return torchtrt_input_list @@ -122,7 +127,9 @@ def prepare_inputs( elif isinstance(inputs, tuple): torchtrt_inputs_tup = [] for input_obj in inputs: - torchtrt_input = prepare_inputs(input_obj) + torchtrt_input = prepare_inputs( + input_obj, disable_memory_format_check=disable_memory_format_check + ) torchtrt_inputs_tup.append(torchtrt_input) return tuple(torchtrt_inputs_tup) @@ -131,7 +138,9 @@ def prepare_inputs( torchtrt_inputs_dict: Dict[Any, Any] = dict() for key, input_obj in inputs.items(): - torchtrt_input = prepare_inputs(input_obj) + torchtrt_input = prepare_inputs( + input_obj, disable_memory_format_check=disable_memory_format_check + ) torchtrt_inputs_dict[key] = torchtrt_input return torchtrt_inputs_dict diff --git a/tests/py/dynamo/conversion/test_slice_aten.py b/tests/py/dynamo/conversion/test_slice_aten.py index 60492aac62..8c0d6dae42 100644 --- a/tests/py/dynamo/conversion/test_slice_aten.py +++ b/tests/py/dynamo/conversion/test_slice_aten.py @@ -1,38 +1,20 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests + from torch_tensorrt import Input from .harness import DispatchTestCase -class TestSelectConverterImplicitBatch(DispatchTestCase): +class TestSelectConverter(DispatchTestCase): @parameterized.expand( [ ("select_dim_start_stop_step", 0, 0, 7, 2), - ] - ) - def test_slice(self, _, dim, start, stop, step): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, input): - out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) - return out - - input = [torch.randn(10, 2, 3, 1)] - self.run_test( - TestModule(), - input, - ) - - -class TestSelectConverterExplicitBatch(DispatchTestCase): - @parameterized.expand( - [ - ("select_dim_start_stop_step", 1, 0, 7, 2), + ("select_dim_start_stop_step_offset", 1, 0, 7, 2), ("select_dim_start_stop_step_exact", 1, 0, 10, 2), + ("select_dim_start_stop_step_negatives", -3, -2, -1, 1), + ("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1), ] ) def test_slice(self, _, dim, start, stop, step): diff --git a/tests/py/dynamo/conversion/test_sum_aten.py b/tests/py/dynamo/conversion/test_sum_aten.py index b327f6ac18..999b8b4997 100644 --- a/tests/py/dynamo/conversion/test_sum_aten.py +++ b/tests/py/dynamo/conversion/test_sum_aten.py @@ -70,8 +70,8 @@ def forward(self, x): @parameterized.expand( [ - ((3, 2, 4), 1, True, torch.int, 0, 5), - ((2, 3, 4, 5), None, True, torch.int, -10, 10), + ((3, 2, 4), 1, True, torch.int32, 0, 5), + ((2, 3, 4, 5), None, True, torch.int32, -10, 10), ((2, 3, 4, 5), 2, False, torch.int32, -5, 0), ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5), ] @@ -85,16 +85,18 @@ def forward(self, x): self.run_test( Sum(), inputs, - check_dtype=False, + output_dtypes=[torch.int32], ) @parameterized.expand( [ - ((1, 2, 4), [], True, torch.int, 0, 5), - ((3, 2, 4), [1], True, torch.int, 0, 5), - ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10), + ((1, 2, 4), [], True, torch.int32, 0, 5), + ((3, 2, 4), [1], True, torch.int32, 0, 5), + ((2, 1, 4, 5), [0, 3], True, torch.int32, -10, 10), ((2, 3, 4, 5), None, False, torch.int32, -5, 0), ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5), + ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.bool, 0, 2), + ((4, 7, 1, 5), None, True, torch.bool, 0, 2), ] ) def test_sum_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high): @@ -106,7 +108,7 @@ def forward(self, x): self.run_test( Sum(), inputs, - check_dtype=False, + output_dtypes=[torch.int32], )