From 7315bb736beecccb935b51d3433b3ec379d5003e Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 23 Apr 2024 09:06:58 -0700
Subject: [PATCH] fix: Remove references to implicit batch for TRT 10

---
 .../dynamo/conversion/impl/normalization/ops.py |  9 +++------
 .../dynamo/conversion/impl/select.py            | 15 ++++-----------
 .../dynamo/conversion/impl/squeeze.py           |  5 +----
 .../dynamo/conversion/impl/unsqueeze.py         | 10 +---------
 4 files changed, 9 insertions(+), 30 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index bbe566d0b7..b1e4fbf24c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -58,7 +58,7 @@ def batch_norm(
 
     # For BatchNorm1d, reshape 1d to 2d
     output_shape = input.shape
-    if not ctx.net.has_implicit_batch_dimension and len(input.shape) < 4:
+    if len(input.shape) < 4:
         assert (
             len(get_dynamic_dims(input.shape)) <= 1
         ), "BatchNorm1D with more than one dynamic dims is not currently supported."
@@ -75,7 +75,7 @@ def batch_norm(
     output = layer.get_output(0)
 
     # For BatchNorm1d, reshape output back to 1d
-    if not ctx.net.has_implicit_batch_dimension and len(output_shape) < 4:
+    if len(output_shape) < 4:
         output = impl.shuffle.reshape(
             ctx,
             target,
@@ -411,7 +411,7 @@ def softmax(
     input: TRTTensor,
     dim: Optional[Any] = None,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0)
+    input_ranks = len(input.shape)
 
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
@@ -433,9 +433,6 @@ def get_softmax_dim(ndim: int) -> int:
         dim = cast(int, dim)
 
     dim = get_positive_dim(dim, input_ranks)
-    if ctx.net.has_implicit_batch_dimension:
-        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
-        dim -= 1
 
     layer = ctx.net.add_softmax(input)
     layer.axes = 1 << dim
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index 6f827de2eb..2ec6420e0b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -40,19 +40,12 @@ def select(
             "of the TensorRT region!"
         )
 
-    ranks = len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0)
+    ranks = len(input.shape)
     dim = get_positive_dim(cast(int, dim), ranks)
     dynamic_shape = has_dynamic_shape(input.shape)
-    if ctx.net.has_implicit_batch_dimension:
-        if dim == 0:
-            raise RuntimeError(
-                f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!"
-            )
-        dim = dim - 1
-    else:
-        if dynamic_shape:
-            # Check whether slice target dim is dynamic shape dim
-            assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
+    if dynamic_shape:
+        # Check whether slice target dim is dynamic shape dim
+        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
     index = index
 
     if index >= input.shape[dim]:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
index cde4fdd90d..45bdefcd80 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
@@ -32,11 +32,8 @@ def squeeze(
     for dim in dims:
         dim = get_positive_dim(
             dim,
-            len(input.shape) + (1 if ctx.net.has_implicit_batch_dimension else 0),
+            len(input.shape),
         )
-        if ctx.net.has_implicit_batch_dimension:
-            assert dim != 0, "We don't support squeeze batch dim when it's implicit."
-            dim -= 1
 
         assert input.shape[dim] != -1, "We don't support squeeze dynamic dim."
         assert (
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
index ce893f8d5b..d056b8f0e8 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
@@ -29,17 +29,9 @@ def unsqueeze(
 
     dim = cast(int, dim)
 
-    input_shape_size = (
-        len(input_val.shape) + 1
-        if ctx.net.has_implicit_batch_dimension
-        else len(input_val.shape)
-    )
+    input_shape_size = len(input_val.shape)
     dim = get_positive_dim(dim, input_shape_size + 1)
 
-    if ctx.net.has_implicit_batch_dimension:
-        assert dim != 0
-        dim -= 1
-
     assert (
         len(get_dynamic_dims(input_val.shape)) <= 1
     ), "Currently we don't support unsqueeze with more than one dynamic dims."
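Reviewer note (not part of the patch itself): all four deletions follow the same pattern. TensorRT 10 drops implicit-batch networks, so tensor shapes reported by `ITensor.shape` always include the batch axis, and the "add one to the rank, then subtract one from the axis" bookkeeping is dead code. A minimal sketch of the axis arithmetic being simplified; `positive_dim` here is a hypothetical stand-in for this repo's `get_positive_dim`, not its actual implementation:

    def positive_dim(dim: int, ndim: int) -> int:
        """Map a possibly negative axis index onto the range [0, ndim)."""
        return dim % ndim

    # Explicit batch (TRT 10): a 4D input (N, C, H, W) reports len(shape) == 4,
    # so dim = -1 resolves directly to axis 3.
    assert positive_dim(-1, 4) == 3

    # The deleted implicit-batch branches instead normalized against ndim + 1
    # and then subtracted 1, reaching the same axis with extra bookkeeping:
    assert positive_dim(-1, 5) - 1 == 3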