-from typing import Iterable, List, NamedTuple, Sequence, Tuple
+from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple

import torch

from .types import Shape, ShapeRange
from .utils import get_dynamic_dims


+def generate_input_specs(
+    inputs, lower_setting, additional_inputs=None, fixed_shape=False
+):
+    # The AIT lower setting doesn't have an explicit_batch_dimension field,
+    # so we just return None.
+    if not hasattr(lower_setting, "explicit_batch_dimension"):
+        return None
+
+    if not lower_setting.explicit_batch_dimension or fixed_shape:
+        return InputTensorSpec.from_tensors(inputs)
+
+    # If we don't have additional inputs, we assume the first dimension
+    # is the dynamic batch dimension. Otherwise, we use the additional
+    # inputs to determine the batch dimension.
+    if additional_inputs is None:
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+        )
+    else:
+        batch_dims = []
+
+        for i, j in zip(inputs, additional_inputs):
+            found_batch_dim = False
+
+            for idx, values in enumerate(zip(i.shape, j.shape)):
+                if values[0] != values[1]:
+                    assert (
+                        found_batch_dim is False
+                    ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+                    batch_dims.append(idx)
+                    found_batch_dim = True
+
+            if not found_batch_dim:
+                raise RuntimeError(
+                    f"Failed to find batch dimension because shapes are the same, {i.shape}"
+                )
+
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+            batch_dims,
+        )
+
+
class InputTensorSpec(NamedTuple):
    """
    This class contains the information of an input tensor.
@@ -70,6 +125,7 @@ def from_tensors_with_dynamic_batch_size(
        tensors: Sequence[torch.Tensor],
        batch_size_range: Tuple[int, int, int],
        opt_profile_replica: int = 1,
+        batch_dims: Optional[List[int]] = None,
    ) -> List["InputTensorSpec"]:
        """
        Produce a list of InputTensorSpec named tuples which would contain
@@ -83,20 +139,30 @@ def from_tensors_with_dynamic_batch_size(
                the smallest batch size allowed. The second integer indicates
                the batch size that we'll optimize for. The third integer indicates
                the largest batch size allowed.
+            opt_profile_replica (int): If dynamic shape is enabled, each execution
+                context requires a different optimization profile. This arg determines
+                how many optimization profile replicas we want to produce.
+            batch_dims (Optional[List[int]]): The batch dim might not be the leading dim,
+                so this arg allows the user to specify the batch dims. By default we treat
+                dim 0 as the batch dim.

        Returns:
            A list of InputTensorSpec named tuples with dynamic ranges.
        """
+        if batch_dims is None:
+            batch_dims = [0] * len(tensors)
+
        input_specs = []
-        batch_size = tensors[0].size(0)
+        batch_size = tensors[0].size(batch_dims[0])

        for i, tensor in enumerate(tensors):
+            batch_dim = batch_dims[i]
            assert batch_size == tensor.size(
-                0
+                batch_dim
            ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
            shape = list(tensor.shape)
-            shape[0] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
+            shape[batch_dim] = -1
+            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
            input_specs.append(
                cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
            )
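
As a quick illustration of the new helper (not part of the patch): the sketch below assumes generate_input_specs and InputTensorSpec from the patched module are in scope, and DummyLowerSetting is a hypothetical stand-in that only models the three fields the helper reads (explicit_batch_dimension, max_batch_size, opt_profile_replica).

# Hypothetical usage sketch; DummyLowerSetting is an illustration only.
from dataclasses import dataclass

import torch


@dataclass
class DummyLowerSetting:
    explicit_batch_dimension: bool = True
    max_batch_size: int = 32
    opt_profile_replica: int = 1


# Inputs whose batch dim is dim 1, not dim 0.
inputs = [torch.randn(8, 4, 16)]
# Additional inputs that differ only along dim 1, so the element-wise
# shape comparison in generate_input_specs marks dim 1 as the batch dim.
additional_inputs = [torch.randn(8, 2, 16)]

specs = generate_input_specs(
    inputs, DummyLowerSetting(), additional_inputs=additional_inputs
)
print(specs[0].shape)         # (8, -1, 16): batch dim replaced by -1
print(specs[0].shape_ranges)  # [((8, 0, 16), (8, 32, 16), (8, 32, 16))]

With additional_inputs omitted, the helper falls back to treating dim 0 as the dynamic batch dim, which matches the previous behavior.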
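The only subtle change inside the loop is how shape_ranges is assembled once the batch dim may not be the leading dim. The snippet below shows that slicing in isolation, using local names for illustration only.

# Illustration of the new shape_ranges expression for a non-leading batch dim.
shape = [8, -1, 16]             # batch dim (dim 1) already overwritten with -1
batch_dim = 1
batch_size_range = (0, 32, 32)  # (min, opt, max) batch sizes
opt_profile_replica = 2

shape_ranges = [
    tuple(
        tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :])
        for bs in batch_size_range
    )
] * opt_profile_replica

print(shape_ranges)
# [((8, 0, 16), (8, 32, 16), (8, 32, 16)),
#  ((8, 0, 16), (8, 32, 16), (8, 32, 16))]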