Skip to content

Commit 06a6694

Browse files
committed
Switch Tuple[int, int, int] to List[int] to more easily support the 2D case
1 parent 881565c commit 06a6694

File tree

1 file changed

+39
-34
lines changed

1 file changed

+39
-34
lines changed

torchvision/models/video/mvitv2.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from functools import partial
3-
from typing import Any, Callable, List, Optional, Sequence, Tuple, cast
3+
from typing import Any, Callable, List, Optional, Sequence, Tuple
44

55
import torch
66
import torch.fx
@@ -94,11 +94,11 @@ def __init__(
9494
self,
9595
embed_dim: int,
9696
num_heads: int,
97+
kernel_q: List[int],
98+
kernel_kv: List[int],
99+
stride_q: List[int],
100+
stride_kv: List[int],
97101
dropout: float = 0.0,
98-
kernel_q: Tuple[int, int, int] = (1, 1, 1),
99-
kernel_kv: Tuple[int, int, int] = (1, 1, 1),
100-
stride_q: Tuple[int, int, int] = (1, 1, 1),
101-
stride_kv: Tuple[int, int, int] = (1, 1, 1),
102102
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
103103
) -> None:
104104
super().__init__()
@@ -115,14 +115,14 @@ def __init__(
115115

116116
self.pool_q: Optional[nn.Module] = None
117117
if _prod(kernel_q) > 1 or _prod(stride_q) > 1:
118-
padding_q = cast(Tuple[int, int, int], tuple(int(q // 2) for q in kernel_q))
118+
padding_q = [int(q // 2) for q in kernel_q]
119119
self.pool_q = Pool(
120120
nn.Conv3d(
121121
self.head_dim,
122122
self.head_dim,
123-
kernel_q,
124-
stride=stride_q,
125-
padding=padding_q,
123+
kernel_q, # type: ignore[arg-type]
124+
stride=stride_q, # type: ignore[arg-type]
125+
padding=padding_q, # type: ignore[arg-type]
126126
groups=self.head_dim,
127127
bias=False,
128128
),
@@ -132,14 +132,14 @@ def __init__(
132132
self.pool_k: Optional[nn.Module] = None
133133
self.pool_v: Optional[nn.Module] = None
134134
if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1:
135-
padding_kv = cast(Tuple[int, int, int], tuple(int(kv // 2) for kv in kernel_kv))
135+
padding_kv = [int(kv // 2) for kv in kernel_kv]
136136
self.pool_k = Pool(
137137
nn.Conv3d(
138138
self.head_dim,
139139
self.head_dim,
140-
kernel_kv,
141-
stride=stride_kv,
142-
padding=padding_kv,
140+
kernel_kv, # type: ignore[arg-type]
141+
stride=stride_kv, # type: ignore[arg-type]
142+
padding=padding_kv, # type: ignore[arg-type]
143143
groups=self.head_dim,
144144
bias=False,
145145
),
@@ -149,9 +149,9 @@ def __init__(
149149
nn.Conv3d(
150150
self.head_dim,
151151
self.head_dim,
152-
kernel_kv,
153-
stride=stride_kv,
154-
padding=padding_kv,
152+
kernel_kv, # type: ignore[arg-type]
153+
stride=stride_kv, # type: ignore[arg-type]
154+
padding=padding_kv, # type: ignore[arg-type]
155155
groups=self.head_dim,
156156
bias=False,
157157
),
@@ -185,21 +185,23 @@ def __init__(
185185
input_channels: int,
186186
output_channels: int,
187187
num_heads: int,
188+
kernel_q: List[int],
189+
kernel_kv: List[int],
190+
stride_q: List[int],
191+
stride_kv: List[int],
188192
dropout: float = 0.0,
189193
stochastic_depth_prob: float = 0.0,
190-
kernel_q: Tuple[int, int, int] = (1, 1, 1),
191-
kernel_kv: Tuple[int, int, int] = (1, 1, 1),
192-
stride_q: Tuple[int, int, int] = (1, 1, 1),
193-
stride_kv: Tuple[int, int, int] = (1, 1, 1),
194194
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
195195
) -> None:
196196
super().__init__()
197197

198198
self.pool_skip: Optional[nn.Module] = None
199199
if _prod(stride_q) > 1:
200-
kernel_skip = cast(Tuple[int, int, int], tuple(s + 1 if s > 1 else s for s in stride_q))
201-
padding_skip = cast(Tuple[int, int, int], tuple(int(k // 2) for k in kernel_skip))
202-
self.pool_skip = Pool(nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None)
200+
kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
201+
padding_skip = [int(k // 2) for k in kernel_skip]
202+
self.pool_skip = Pool(
203+
nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None # type: ignore[arg-type]
204+
)
203205

204206
self.norm1 = norm_layer(input_channels)
205207
self.norm2 = norm_layer(input_channels)
@@ -208,11 +210,11 @@ def __init__(
208210
self.attn = MultiscaleAttention(
209211
input_channels,
210212
num_heads,
211-
dropout=dropout,
212213
kernel_q=kernel_q,
213214
kernel_kv=kernel_kv,
214215
stride_q=stride_q,
215216
stride_kv=stride_kv,
217+
dropout=dropout,
216218
norm_layer=norm_layer,
217219
)
218220
self.mlp = MLP(
@@ -270,9 +272,9 @@ def __init__(
270272
embed_channels: List[int],
271273
blocks: List[int],
272274
heads: List[int],
273-
pool_kv_stride: Tuple[int, int, int] = (1, 8, 8),
274-
pool_q_stride: Tuple[int, int, int] = (1, 2, 2),
275-
pool_kvq_kernel: Tuple[int, int, int] = (3, 3, 3),
275+
pool_kv_stride: List[int],
276+
pool_q_stride: List[int],
277+
pool_kvq_kernel: List[int],
276278
dropout: float = 0.0,
277279
attention_dropout: float = 0.0,
278280
stochastic_depth_prob: float = 0.0,
@@ -289,9 +291,9 @@ def __init__(
289291
embed_channels (list of ints): A list with the embedding dimensions of each block group.
290292
blocks (list of ints): A list with the number of blocks of each block group.
291293
heads (list of ints): A list with the number of heads of each block group.
292-
pool_kv_stride (tuple of ints): The initialize pooling stride of the first block.
293-
pool_q_stride (tuple of ints): The pooling stride which reduces q in each block group.
294-
pool_kvq_kernel (tuple of ints): The pooling kernel for the attention.
294+
pool_kv_stride (list of ints): The initial pooling stride of the first block.
295+
pool_q_stride (list of ints): The pooling stride which reduces q in each block group.
296+
pool_kvq_kernel (list of ints): The pooling kernel for the attention.
295297
dropout (float): Dropout rate. Default: 0.0.
296298
attention_dropout (float): Attention dropout rate. Default: 0.0.
297299
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
@@ -343,12 +345,12 @@ def __init__(
343345
next_block_index = i + 1 if j + 1 == num_subblocks and i + 1 < num_blocks else i
344346
output_channels = embed_channels[next_block_index]
345347

346-
stride_q = (1, 1, 1)
348+
stride_q = [1, 1, 1]
347349
if pool_countdown == 0:
348350
stride_q = pool_q_stride
349351
pool_countdown = blocks[next_block_index]
350352

351-
stride_kv = cast(Tuple[int, int, int], tuple(max(s // stride_q[d], 1) for d, s in enumerate(stride_kv)))
353+
stride_kv = [max(s // stride_q[d], 1) for d, s in enumerate(stride_kv)]
352354

353355
# adjust stochastic depth probability based on the depth of the stage block
354356
sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
@@ -358,12 +360,12 @@ def __init__(
358360
input_channels=input_channels,
359361
output_channels=output_channels,
360362
num_heads=heads[i],
361-
dropout=attention_dropout,
362-
stochastic_depth_prob=sd_prob,
363363
kernel_q=pool_kvq_kernel,
364364
kernel_kv=pool_kvq_kernel,
365365
stride_q=stride_q,
366366
stride_kv=stride_kv,
367+
dropout=attention_dropout,
368+
stochastic_depth_prob=sd_prob,
367369
norm_layer=norm_layer,
368370
)
369371
)
@@ -437,6 +439,9 @@ def _mvitv2(
437439
embed_channels=embed_channels,
438440
blocks=blocks,
439441
heads=heads,
442+
pool_kv_stride=kwargs.pop("pool_kv_stride", [1, 8, 8]),
443+
pool_q_stride=kwargs.pop("pool_q_stride", [1, 2, 2]),
444+
pool_kvq_kernel=kwargs.pop("pool_kvq_kernel", [3, 3, 3]),
440445
stochastic_depth_prob=stochastic_depth_prob,
441446
**kwargs,
442447
)

0 commit comments

Comments
 (0)