@@ -9,7 +9,7 @@ class TestCatConverter(DispatchTestCase):
9
9
@parameterized .expand (
10
10
[
11
11
("pos" , 1 ),
12
- # ("neg", -2), #dim can not have dynamic input
12
+ ("neg" , - 2 ),
13
13
]
14
14
)
15
15
def test_cat (self , _ , dim ):
@@ -27,7 +27,7 @@ def forward(self, x, y, z):
27
27
@parameterized .expand (
28
28
[
29
29
("pos" , 1 ),
30
- # ("neg", -2), #dim can not have dynamic input
30
+ ("neg" , - 2 ),
31
31
]
32
32
)
33
33
def test_cat_dynamic_shape (self , _ , dim ):
@@ -53,6 +53,41 @@ def forward(self, x, y):
53
53
expected_ops = {torch .ops .aten .cat .default },
54
54
)
55
55
56
+ def test_cat_no_dim (self ):
57
+ class Cat (nn .Module ):
58
+ def forward (self , x , y , z ):
59
+ return torch .cat ((x , y , z ))
60
+
61
+ inputs = [torch .randn (2 , 1 , 3 ), torch .randn (1 , 1 , 3 ), torch .randn (3 , 1 , 3 )]
62
+ self .run_test (
63
+ Cat (),
64
+ inputs ,
65
+ expected_ops = {torch .ops .aten .cat .default },
66
+ )
67
+
68
+ def test_cat_dynamic_shape_no_dim (self ):
69
+ class Cat (nn .Module ):
70
+ def forward (self , x , y ):
71
+ return torch .cat ((x , y ))
72
+
73
+ input_specs = [
74
+ InputTensorSpec (
75
+ shape = (- 1 , 16 , 3 ),
76
+ dtype = torch .float32 ,
77
+ shape_ranges = [((2 , 16 , 3 ), (3 , 16 , 3 ), (32 , 16 , 3 ))],
78
+ ),
79
+ InputTensorSpec (
80
+ shape = (- 1 , 16 , 3 ),
81
+ dtype = torch .float32 ,
82
+ shape_ranges = [((2 , 16 , 3 ), (3 , 16 , 3 ), (32 , 16 , 3 ))],
83
+ ),
84
+ ]
85
+ self .run_test_with_dynamic_shape (
86
+ Cat (),
87
+ input_specs ,
88
+ expected_ops = {torch .ops .aten .cat .default },
89
+ )
90
+
56
91
57
92
if __name__ == "__main__" :
58
93
run_tests ()
0 commit comments