Skip to content

Commit 9d52c6a

Browse files
committed
fix a squeeze bug
1 parent b2ac5f0 commit 9d52c6a

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import re
3-
from typing import List, Optional
3+
from typing import Any, List, Optional, Tuple
44

55
import tensorrt as trt
66
import torch
@@ -157,3 +157,24 @@ def broadcastable(
157157
if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1):
158158
return False
159159
return True
160+
161+
162+
def extend_attr_to_tuple(
    val: Any,
    num_elem: int,
) -> Tuple[Any, ...]:
    """
    Normalize an attribute value to a tuple.

    If `val` is not a tuple or a list, we make a tuple of size `num_elem` by
    replicating `val` `num_elem` times. If `val` is already a list, it is
    converted to a tuple; an existing tuple is returned unchanged.

    Args:
        val (Any): Value that we want to process.
        num_elem (int): Number of times to replicate `val` when it is a
            scalar (ignored when `val` is already a tuple or list).

    Returns:
        A tuple: `(val,) * num_elem` for a scalar `val`, otherwise
        `tuple(val)`. Note that the length of a passed-in list/tuple is
        NOT checked against `num_elem`.
    """
    if not isinstance(val, (tuple, list)):
        val = (val,) * num_elem
    # A list passed by the caller is converted so the return type is
    # always a tuple.
    if isinstance(val, list):
        val = tuple(val)
    return val

py/torch_tensorrt/dynamo/conversion/impl/conv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import torch
88
from torch.fx.node import Target
99
from torch_tensorrt.dynamo.conversion import aten_ops_converters
10+
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
1011
from torch_tensorrt.fx.converters.converter_utils import (
1112
SourceIR,
12-
extend_attr_to_tuple,
1313
get_dyn_range,
1414
get_trt_tensor,
1515
has_dynamic_shape,

py/torch_tensorrt/dynamo/conversion/impl/squeeze.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ def squeeze(
1818
input: TRTTensor,
1919
dim: Optional[Any] = None,
2020
) -> TRTTensor:
21-
if not isinstance(input, TRTTensor):
22-
raise RuntimeError(
23-
f"squeeze received input {input} that is not part "
24-
"of the TensorRT region!"
25-
)
2621
dims = []
2722
if dim is not None:
2823
if isinstance(dim, int):
@@ -35,6 +30,7 @@ def squeeze(
3530
# dim, which is a very rare case. For now we just claim not supporting dim=None.
3631
assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."
3732

33+
new_dims = []
3834
for dim in dims:
3935
dim = get_positive_dim(
4036
dim,
@@ -48,13 +44,14 @@ def squeeze(
4844
assert (
4945
len(get_dynamic_dims(input.shape)) <= 1
5046
), "Currently more than one dynamic dim for input to squeeze is not supported."
47+
new_dims.append(dim)
5148

5249
output_shape = []
5350
for i, s in enumerate(input.shape):
54-
if (i in dims) and s == 1:
51+
if (i in new_dims) and s == 1:
5552
continue
5653
output_shape.append(s)
5754
layer = network.add_shuffle(input)
5855
layer.reshape_dims = tuple(output_shape)
59-
set_layer_name(layer, target, name)
56+
set_layer_name(layer, target, name, source_ir)
6057
return layer.get_output(0)

0 commit comments

Comments
 (0)