
Commit 350e207

fix: Add support for default dimension in aten.cat
- Add a default `dim=0` to the concatenation converter for call sites that do not specify a concatenation dimension; T5 hits this error during compilation
- Add test cases that exercise the default dimension (and reproduce the error without this fix)
Parent: b3f433a · Commit: 350e207

2 files changed: +38 −3

py/torch_tensorrt/fx/converters/aten_ops_converters.py

+1 −1
@@ -358,7 +358,7 @@ def aten_ops_cat(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     kwargs_new = {
         "tensors": args[0],
-        "dim": args[1],
+        "dim": args[1] if len(args) >= 2 else 0,
     }
     return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)
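The fallback mirrors the schema of aten::cat, whose dim argument defaults to 0, so a traced call that omits the dimension still lowers correctly. A minimal standalone sketch of the argument handling performed by the one-line change above (not part of the patch; resolve_cat_dim is a hypothetical name used only for illustration):

# Sketch of the fallback: when the traced call carries only the tensor list,
# fall back to dim 0, matching aten::cat's default.
def resolve_cat_dim(args):
    # args mirrors the traced call site: (tensors, dim) or just (tensors,)
    return args[1] if len(args) >= 2 else 0

print(resolve_cat_dim((["t0", "t1"], -2)))  # explicit dim passes through -> -2
print(resolve_cat_dim((["t0", "t1"],)))     # dim omitted in the trace   -> 0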

py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py

+37 −2
@@ -9,7 +9,7 @@ class TestCatConverter(DispatchTestCase):
     @parameterized.expand(
         [
             ("pos", 1),
-            # ("neg", -2), #Dynamo tracer issue
+            ("neg", -2),
         ]
     )
     def test_cat(self, _, dim):
@@ -27,7 +27,7 @@ def forward(self, x, y, z):
     @parameterized.expand(
         [
             ("pos", 1),
-            # ("neg", -2), #Dynamo tracer issue
+            ("neg", -2),
         ]
     )
     def test_cat_dynamic_shape(self, _, dim):
@@ -53,6 +53,41 @@ def forward(self, x, y):
             expected_ops={torch.ops.aten.cat.default},
         )

+    def test_cat_no_dim(self):
+        class Cat(nn.Module):
+            def forward(self, x, y, z):
+                return torch.cat((x, y, z))
+
+        inputs = [torch.randn(2, 1, 3), torch.randn(1, 1, 3), torch.randn(3, 1, 3)]
+        self.run_test(
+            Cat(),
+            inputs,
+            expected_ops={torch.ops.aten.cat.default},
+        )
+
+    def test_cat_dynamic_shape_no_dim(self):
+        class Cat(nn.Module):
+            def forward(self, x, y):
+                return torch.cat((x, y))
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 16, 3),
+                dtype=torch.float32,
+                shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
+            ),
+            InputTensorSpec(
+                shape=(-1, 16, 3),
+                dtype=torch.float32,
+                shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Cat(),
+            input_specs,
+            expected_ops={torch.ops.aten.cat.default},
+        )
+

 if __name__ == "__main__":
     run_tests()
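For reference, the behavior the new tests pin down matches eager-mode PyTorch: torch.cat with no dim argument concatenates along dimension 0. A quick standalone check outside the test harness, using the same shapes as test_cat_no_dim:

import torch

x, y, z = torch.randn(2, 1, 3), torch.randn(1, 1, 3), torch.randn(3, 1, 3)
out = torch.cat((x, y, z))          # no dim given, so dim=0 is used
assert out.shape == (6, 1, 3)       # sizes along dim 0 add up: 2 + 1 + 3
assert torch.equal(out, torch.cat((x, y, z), dim=0))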
