 import math
 from functools import partial
-from typing import Any, Callable, List, Optional, Sequence, Tuple, cast
+from typing import Any, Callable, List, Optional, Sequence, Tuple

 import torch
 import torch.fx
@@ -94,11 +94,11 @@ def __init__(
         self,
         embed_dim: int,
         num_heads: int,
+        kernel_q: List[int],
+        kernel_kv: List[int],
+        stride_q: List[int],
+        stride_kv: List[int],
         dropout: float = 0.0,
-        kernel_q: Tuple[int, int, int] = (1, 1, 1),
-        kernel_kv: Tuple[int, int, int] = (1, 1, 1),
-        stride_q: Tuple[int, int, int] = (1, 1, 1),
-        stride_kv: Tuple[int, int, int] = (1, 1, 1),
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
     ) -> None:
         super().__init__()
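With this change the pooling kernels and strides of MultiscaleAttention become required positional arguments instead of tuple-valued defaults. A minimal construction sketch, assuming the signature shown in the hunk above; the import path, embedding width, head count and pooling values below are illustrative assumptions, not taken from this diff:

# Assumed import path; the class lives in torchvision's mvit module in current releases.
from torchvision.models.video.mvit import MultiscaleAttention

attn = MultiscaleAttention(
    96,                   # embed_dim (illustrative)
    1,                    # num_heads (illustrative)
    kernel_q=[3, 3, 3],
    kernel_kv=[3, 3, 3],
    stride_q=[1, 2, 2],
    stride_kv=[1, 8, 8],
    dropout=0.0,
)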
@@ -115,14 +115,14 @@ def __init__(

         self.pool_q: Optional[nn.Module] = None
         if _prod(kernel_q) > 1 or _prod(stride_q) > 1:
-            padding_q = cast(Tuple[int, int, int], tuple(int(q // 2) for q in kernel_q))
+            padding_q = [int(q // 2) for q in kernel_q]
             self.pool_q = Pool(
                 nn.Conv3d(
                     self.head_dim,
                     self.head_dim,
-                    kernel_q,
-                    stride=stride_q,
-                    padding=padding_q,
+                    kernel_q,  # type: ignore[arg-type]
+                    stride=stride_q,  # type: ignore[arg-type]
+                    padding=padding_q,  # type: ignore[arg-type]
                     groups=self.head_dim,
                     bias=False,
                 ),
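The padding is now computed as a plain list, but for odd kernels it is still the usual kernel // 2 "same"-style padding. A tiny worked example of the comprehension above, with an illustrative kernel:

kernel_q = [3, 3, 3]
padding_q = [int(q // 2) for q in kernel_q]  # -> [1, 1, 1]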
@@ -132,14 +132,14 @@ def __init__(
         self.pool_k: Optional[nn.Module] = None
         self.pool_v: Optional[nn.Module] = None
         if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1:
-            padding_kv = cast(Tuple[int, int, int], tuple(int(kv // 2) for kv in kernel_kv))
+            padding_kv = [int(kv // 2) for kv in kernel_kv]
             self.pool_k = Pool(
                 nn.Conv3d(
                     self.head_dim,
                     self.head_dim,
-                    kernel_kv,
-                    stride=stride_kv,
-                    padding=padding_kv,
+                    kernel_kv,  # type: ignore[arg-type]
+                    stride=stride_kv,  # type: ignore[arg-type]
+                    padding=padding_kv,  # type: ignore[arg-type]
                     groups=self.head_dim,
                     bias=False,
                 ),
@@ -149,9 +149,9 @@ def __init__(
                 nn.Conv3d(
                     self.head_dim,
                     self.head_dim,
-                    kernel_kv,
-                    stride=stride_kv,
-                    padding=padding_kv,
+                    kernel_kv,  # type: ignore[arg-type]
+                    stride=stride_kv,  # type: ignore[arg-type]
+                    padding=padding_kv,  # type: ignore[arg-type]
                     groups=self.head_dim,
                     bias=False,
                 ),
@@ -185,21 +185,23 @@ def __init__(
         input_channels: int,
         output_channels: int,
         num_heads: int,
+        kernel_q: List[int],
+        kernel_kv: List[int],
+        stride_q: List[int],
+        stride_kv: List[int],
         dropout: float = 0.0,
         stochastic_depth_prob: float = 0.0,
-        kernel_q: Tuple[int, int, int] = (1, 1, 1),
-        kernel_kv: Tuple[int, int, int] = (1, 1, 1),
-        stride_q: Tuple[int, int, int] = (1, 1, 1),
-        stride_kv: Tuple[int, int, int] = (1, 1, 1),
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
     ) -> None:
         super().__init__()

         self.pool_skip: Optional[nn.Module] = None
         if _prod(stride_q) > 1:
-            kernel_skip = cast(Tuple[int, int, int], tuple(s + 1 if s > 1 else s for s in stride_q))
-            padding_skip = cast(Tuple[int, int, int], tuple(int(k // 2) for k in kernel_skip))
-            self.pool_skip = Pool(nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None)
+            kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
+            padding_skip = [int(k // 2) for k in kernel_skip]
+            self.pool_skip = Pool(
+                nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None  # type: ignore[arg-type]
+            )

         self.norm1 = norm_layer(input_channels)
         self.norm2 = norm_layer(input_channels)
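Worked through with the default query stride that the _mvitv2 builder at the end of this diff supplies, the skip-pool shapes computed above come out as follows (values only illustrate the two list comprehensions):

stride_q = [1, 2, 2]
kernel_skip = [s + 1 if s > 1 else s for s in stride_q]  # -> [1, 3, 3]
padding_skip = [int(k // 2) for k in kernel_skip]        # -> [0, 1, 1]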
@@ -208,11 +210,11 @@ def __init__(
         self.attn = MultiscaleAttention(
             input_channels,
             num_heads,
-            dropout=dropout,
             kernel_q=kernel_q,
             kernel_kv=kernel_kv,
             stride_q=stride_q,
             stride_kv=stride_kv,
+            dropout=dropout,
             norm_layer=norm_layer,
         )
         self.mlp = MLP(
@@ -270,9 +272,9 @@ def __init__(
         embed_channels: List[int],
         blocks: List[int],
         heads: List[int],
-        pool_kv_stride: Tuple[int, int, int] = (1, 8, 8),
-        pool_q_stride: Tuple[int, int, int] = (1, 2, 2),
-        pool_kvq_kernel: Tuple[int, int, int] = (3, 3, 3),
+        pool_kv_stride: List[int],
+        pool_q_stride: List[int],
+        pool_kvq_kernel: List[int],
         dropout: float = 0.0,
         attention_dropout: float = 0.0,
         stochastic_depth_prob: float = 0.0,
@@ -289,9 +291,9 @@ def __init__(
             embed_channels (list of ints): A list with the embedding dimensions of each block group.
             blocks (list of ints): A list with the number of blocks of each block group.
             heads (list of ints): A list with the number of heads of each block group.
-            pool_kv_stride (tuple of ints): The initialize pooling stride of the first block.
-            pool_q_stride (tuple of ints): The pooling stride which reduces q in each block group.
-            pool_kvq_kernel (tuple of ints): The pooling kernel for the attention.
+            pool_kv_stride (list of ints): The initial pooling stride of the first block.
+            pool_q_stride (list of ints): The pooling stride which reduces q in each block group.
+            pool_kvq_kernel (list of ints): The pooling kernel for the attention.
             dropout (float): Dropout rate. Default: 0.0.
             attention_dropout (float): Attention dropout rate. Default: 0.0.
             stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
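To make the three list parameters concrete, these are the values the _mvitv2 builder in the last hunk of this diff falls back to when the caller does not override them:

pool_config = {
    "pool_kv_stride": [1, 8, 8],   # initial k/v pooling stride for the first block
    "pool_q_stride": [1, 2, 2],    # stride that reduces q at each block group
    "pool_kvq_kernel": [3, 3, 3],  # pooling kernel shared by q, k and v
}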
@@ -343,12 +345,12 @@ def __init__(
                 next_block_index = i + 1 if j + 1 == num_subblocks and i + 1 < num_blocks else i
                 output_channels = embed_channels[next_block_index]

-                stride_q = (1, 1, 1)
+                stride_q = [1, 1, 1]
                 if pool_countdown == 0:
                     stride_q = pool_q_stride
                     pool_countdown = blocks[next_block_index]

-                stride_kv = cast(Tuple[int, int, int], tuple(max(s // stride_q[d], 1) for d, s in enumerate(stride_kv)))
+                stride_kv = [max(s // stride_q[d], 1) for d, s in enumerate(stride_kv)]

                 # adjust stochastic depth probability based on the depth of the stage block
                 sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
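A worked example of the stride bookkeeping above, using the builder defaults from the last hunk: stride_kv starts at pool_kv_stride and is divided down each time a query stride is applied, never dropping below 1.

stride_kv = [1, 8, 8]  # pool_kv_stride for the first block
stride_q = [1, 2, 2]   # pool_q_stride, used when pool_countdown reaches 0
stride_kv = [max(s // stride_q[d], 1) for d, s in enumerate(stride_kv)]  # -> [1, 4, 4]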
@@ -358,12 +360,12 @@ def __init__(
                         input_channels=input_channels,
                         output_channels=output_channels,
                         num_heads=heads[i],
-                        dropout=attention_dropout,
-                        stochastic_depth_prob=sd_prob,
                         kernel_q=pool_kvq_kernel,
                         kernel_kv=pool_kvq_kernel,
                         stride_q=stride_q,
                         stride_kv=stride_kv,
+                        dropout=attention_dropout,
+                        stochastic_depth_prob=sd_prob,
                         norm_layer=norm_layer,
                     )
                 )
@@ -437,6 +439,9 @@ def _mvitv2(
         embed_channels=embed_channels,
         blocks=blocks,
         heads=heads,
+        pool_kv_stride=kwargs.pop("pool_kv_stride", [1, 8, 8]),
+        pool_q_stride=kwargs.pop("pool_q_stride", [1, 2, 2]),
+        pool_kvq_kernel=kwargs.pop("pool_kvq_kernel", [3, 3, 3]),
         stochastic_depth_prob=stochastic_depth_prob,
         **kwargs,
     )
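The kwargs.pop calls reproduce the previous tuple defaults for builder callers while still allowing a keyword override. A minimal, standalone illustration of that pattern (not tied to any particular model builder):

kwargs = {"pool_q_stride": [1, 4, 4]}
pool_q_stride = kwargs.pop("pool_q_stride", [1, 2, 2])      # -> [1, 4, 4]; the key is consumed
pool_kvq_kernel = kwargs.pop("pool_kvq_kernel", [3, 3, 3])  # -> [3, 3, 3], the default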