From 69095dd10cccd24c6f7aeda48d5306850174aabc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 7 Jun 2022 13:28:02 +0100 Subject: [PATCH 1/7] Adding MViT v2 architecture (#6105) * Adding mvitv2 architecture * Fixing memory issues on tests and minor refactorings. * Adding input validation * Adding docs and minor refactoring * Add `min_temporal_size` in the supported meta-data. * Switch Tuple[int, int, int] with List[int] to support easier the 2D case * Adding more docs and references * Change naming conventions of classes to follow the same pattern as MobileNetV3 * Fix test breakage. * Update todos * Performance optimizations. --- docs/source/models.rst | 1 + docs/source/models/video_mvitv2.rst | 28 + .../ModelTester.test_mvit_v2_b_expect.pkl | Bin 0 -> 939 bytes .../ModelTester.test_mvit_v2_s_expect.pkl | Bin 0 -> 939 bytes .../ModelTester.test_mvit_v2_t_expect.pkl | Bin 0 -> 939 bytes test/test_extended_models.py | 1 + test/test_models.py | 12 + torchvision/models/video/__init__.py | 1 + torchvision/models/video/mvitv2.py | 588 ++++++++++++++++++ 9 files changed, 631 insertions(+) create mode 100644 docs/source/models/video_mvitv2.rst create mode 100644 test/expect/ModelTester.test_mvit_v2_b_expect.pkl create mode 100644 test/expect/ModelTester.test_mvit_v2_s_expect.pkl create mode 100644 test/expect/ModelTester.test_mvit_v2_t_expect.pkl create mode 100644 torchvision/models/video/mvitv2.py diff --git a/docs/source/models.rst b/docs/source/models.rst index b549c25bf94..b6f11ae9db0 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -459,6 +459,7 @@ pre-trained weights: .. toctree:: :maxdepth: 1 + models/video_mvitv2 models/video_resnet | diff --git a/docs/source/models/video_mvitv2.rst b/docs/source/models/video_mvitv2.rst new file mode 100644 index 00000000000..e9ad556ded7 --- /dev/null +++ b/docs/source/models/video_mvitv2.rst @@ -0,0 +1,28 @@ +Video ResNet +============ + +.. currentmodule:: torchvision.models.video + +The MViT V2 model is based on the +`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection +`__ and `Multiscale Vision Transformers +`__ papers. + + +Model builders +-------------- + +The following model builders can be used to instantiate a MViTV2 model, with or +without pre-trained weights. All the model builders internally rely on the +``torchvision.models.video.MViTV2`` base class. Please refer to the `source +code +`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + mvit_v2_t + mvit_v2_s + mvit_v2_b diff --git a/test/expect/ModelTester.test_mvit_v2_b_expect.pkl b/test/expect/ModelTester.test_mvit_v2_b_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a39dd2e5754797feaba9c8e10f103900e77a6c98 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5~q{x-o9I|2lkzD+hTX@zM=hbeMS2krpWykA*=1f#ZC6Fv3P1Xt$D^ihs=E2 zW`_-Sf{`oj42s-r0v0Ijy_7R`-^Svr`@%L}-M>0=;r_f|{`;4hEV36gkliQjX|(^- z{}j9TRebi-zGm&4aZJqa|4UW-CAa$a1xp;-7hbQl{{VOXo&qC>yGd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp z0p4tEI#5M&%(`&ppu`LUFnT+L%P zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=U>o(!Q?udi&G=JhWAO>uf(|x4C_RL5lqwr#M^7!iD>e*~i%}|Lr5t{wZNh`=4ETzVD5R;(ocL$#&@lHxrBfrgKjF7X$>`FDPo*)9`=wZV%Q4doC~@ z-FqZ*mE9(7r~S2w#`dv4U)$9rKi*sL_03)z7r%XC{}=CV^XJ#89K2T`Q(|Et_1TdsP7;2j35f0CXwS%03?9|&{HV7Ze&04q3C=C?(c~y%Ind!t_GJ zAi$fAO$Vw-j#(G39F&+r07h?za2Y0nJqhwI8z^ructRC`GC_bhD;r3R83;k@A!-5R CO8RjC literal 0 HcmV?d00001 diff --git a/test/expect/ModelTester.test_mvit_v2_t_expect.pkl b/test/expect/ModelTester.test_mvit_v2_t_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..384fe05b50c9fc4fcc92c39f62d6dc299282ea9f GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK630<(z3r5iKXw5nDSNkg`|tOO6S5UQG;`mh6C3Qgwkp`asd~Gw?Gu~5$V%@0 z)1L3zJ0pDezJEt<+x@X$yN84Iq8&qpo^{*vpL=aOH`%R;+_rC~v4q`}y$9@UPA;%B zIF)JNakO_oLzaP++ea(=yx6<8M<=xIXIXh@FRRlcJ9br}{jyvRb`BqV_d0m%*oty8 z?AP3?X|FYdbAQF&{reA{{bTz)nA5IrVUm6E62|>)wUzc2_jm5=Ygu5|5M;YYBYK~G z(fSK^e{}cl1%=j*hC4bZfFT9KxI>Gd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp z0p4tEI#5M&%(`&ppu`LUFnT+L%P int: + product = 1 + for v in s: + product *= v + return product + + +def _unsqueeze(x: torch.Tensor) -> Tuple[torch.Tensor, int]: + tensor_dim = x.dim() + if tensor_dim == 3: + x = x.unsqueeze(1) + elif tensor_dim != 4: + raise ValueError(f"Unsupported input dimension {x.shape}") + return x, tensor_dim + + +def _squeeze(x: torch.Tensor, tensor_dim: int) -> torch.Tensor: + if tensor_dim == 3: + x = x.squeeze(1) + return x + + +torch.fx.wrap("_unsqueeze") +torch.fx.wrap("_squeeze") + + +class Pool(nn.Module): + def __init__( + self, + pool: nn.Module, + norm: Optional[nn.Module], + activation: Optional[nn.Module] = None, + norm_before_pool: bool = False, + ) -> None: + super().__init__() + self.pool = pool + layers = [] + if norm is not None: + layers.append(norm) + if activation is not None: + layers.append(activation) + self.norm_act = nn.Sequential(*layers) if layers else None + self.norm_before_pool = norm_before_pool + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + x, tensor_dim = _unsqueeze(x) + + # Separate the class token and reshape the input + class_token, x = torch.tensor_split(x, indices=(1,), dim=2) + x = x.transpose(2, 3) + B, N, C = x.shape[:3] + x = x.reshape((B * N, C) + thw).contiguous() + + # normalizing prior pooling is useful when we use BN which can be absorbed to speed up inference + if self.norm_before_pool and self.norm_act is not None: + x = self.norm_act(x) + + # apply the pool on the input and add back the token + x = self.pool(x) + T, H, W = x.shape[2:] + x = x.reshape(B, N, C, -1).transpose(2, 3) + x = torch.cat((class_token, x), dim=2) + + if not self.norm_before_pool and self.norm_act is not None: + x = self.norm_act(x) + + x = _squeeze(x, tensor_dim) + return x, (T, H, W) + + +class MultiscaleAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + kernel_q: List[int], + kernel_kv: List[int], + stride_q: List[int], + stride_kv: List[int], + dropout: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scaler = 1.0 / math.sqrt(self.head_dim) + + self.qkv = nn.Linear(embed_dim, 3 * embed_dim) + layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)] + if dropout > 0.0: + layers.append(nn.Dropout(dropout, inplace=True)) + self.project = nn.Sequential(*layers) + + self.pool_q: Optional[nn.Module] = None + if _prod(kernel_q) > 1 or _prod(stride_q) > 1: + padding_q = [int(q // 2) for q in kernel_q] + self.pool_q = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_q, # type: ignore[arg-type] + stride=stride_q, # type: ignore[arg-type] + padding=padding_q, # type: ignore[arg-type] + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + + self.pool_k: Optional[nn.Module] = None + self.pool_v: Optional[nn.Module] = None + if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1: + padding_kv = [int(kv // 2) for kv in kernel_kv] + self.pool_k = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_kv, # type: ignore[arg-type] + stride=stride_kv, # type: ignore[arg-type] + padding=padding_kv, # type: ignore[arg-type] + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + self.pool_v = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_kv, # type: ignore[arg-type] + stride=stride_kv, # type: ignore[arg-type] + padding=padding_kv, # type: ignore[arg-type] + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + B, N, C = x.shape + q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2) + + if self.pool_k is not None: + k = self.pool_k(k, thw)[0] + if self.pool_v is not None: + v = self.pool_v(v, thw)[0] + if self.pool_q is not None: + q, thw = self.pool_q(q, thw) + + attn = torch.matmul(self.scaler * q, k.transpose(2, 3)) + attn = attn.softmax(dim=-1) + + x = torch.matmul(attn, v).add_(q) + x = x.transpose(1, 2).reshape(B, -1, C) + x = self.project(x) + + return x, thw + + +class MultiscaleBlock(nn.Module): + def __init__( + self, + input_channels: int, + output_channels: int, + num_heads: int, + kernel_q: List[int], + kernel_kv: List[int], + stride_q: List[int], + stride_kv: List[int], + dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + + self.pool_skip: Optional[nn.Module] = None + if _prod(stride_q) > 1: + kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + padding_skip = [int(k // 2) for k in kernel_skip] + self.pool_skip = Pool( + nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None # type: ignore[arg-type] + ) + + self.norm1 = norm_layer(input_channels) + self.norm2 = norm_layer(input_channels) + self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d) + + self.attn = MultiscaleAttention( + input_channels, + num_heads, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + dropout=dropout, + norm_layer=norm_layer, + ) + self.mlp = MLP( + input_channels, + [4 * input_channels, output_channels], + activation_layer=nn.GELU, + dropout=dropout, + inplace=None, + ) + + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + + self.project: Optional[nn.Module] = None + if input_channels != output_channels: + self.project = nn.Linear(input_channels, output_channels) + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] + + x = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x) + x, thw = self.attn(x, thw) + x = x_skip + self.stochastic_depth(x) + + x_norm = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x) + x_proj = x if self.project is None else self.project(x_norm) + + return x_proj + self.stochastic_depth(self.mlp(x_norm)), thw + + +class PositionalEncoding(nn.Module): + def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int) -> None: + super().__init__() + self.spatial_size = spatial_size + self.temporal_size = temporal_size + + self.class_token = nn.Parameter(torch.zeros(embed_size)) + self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size)) + self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size)) + self.class_pos = nn.Parameter(torch.zeros(embed_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hw_size, embed_size = self.spatial_pos.shape + pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0) + pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size)) + pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0) + class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1) + return torch.cat((class_token, x), dim=1).add_(pos_embedding) + + +class MViTV2(nn.Module): + def __init__( + self, + spatial_size: Tuple[int, int], + temporal_size: int, + embed_channels: List[int], + blocks: List[int], + heads: List[int], + pool_kv_stride: List[int], + pool_q_stride: List[int], + pool_kvq_kernel: List[int], + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + num_classes: int = 400, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + """ + MViT V2 main class. + + Args: + spatial_size (tuple of ints): The spacial size of the input as ``(H, W)``. + temporal_size (int): The temporal size ``T`` of the input. + embed_channels (list of ints): A list with the embedding dimensions of each block group. + blocks (list of ints): A list with the number of blocks of each block group. + heads (list of ints): A list with the number of heads of each block group. + pool_kv_stride (list of ints): The initiale pooling stride of the first block. + pool_q_stride (list of ints): The pooling stride which reduces q in each block group. + pool_kvq_kernel (list of ints): The pooling kernel for the attention. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + num_classes (int): The number of classes. + block (callable, optional): Module specifying the layer which consists of the attention and mlp. + norm_layer (callable, optional): Module specifying the normalization layer to use. + """ + super().__init__() + # This implementation employs a different parameterization scheme than the one used at PyTorch Video: + # https://github.com/facebookresearch/pytorchvideo/blob/718d0a4/pytorchvideo/models/vision_transformers.py + # We remove any experimental configuration that didn't make it to the final variants of the models. To represent + # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper. + _log_api_usage_once(self) + num_blocks = len(blocks) + if num_blocks != len(embed_channels) or num_blocks != len(heads): + raise ValueError("The parameters 'embed_channels', 'blocks' and 'heads' must have equal length.") + + if block is None: + block = MultiscaleBlock + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + # Patch Embedding module + self.conv_proj = nn.Conv3d( + in_channels=3, + out_channels=embed_channels[0], + kernel_size=(3, 7, 7), + stride=(2, 4, 4), + padding=(1, 3, 3), + ) + + # Spatio-Temporal Class Positional Encoding + self.pos_encoding = PositionalEncoding( + embed_size=embed_channels[0], + spatial_size=(spatial_size[0] // self.conv_proj.stride[1], spatial_size[1] // self.conv_proj.stride[2]), + temporal_size=temporal_size // self.conv_proj.stride[0], + ) + + # Encoder module + self.blocks = nn.ModuleList() + stage_block_id = 0 + pool_countdown = blocks[0] + input_channels = output_channels = embed_channels[0] + stride_kv = pool_kv_stride + total_stage_blocks = sum(blocks) + for i, num_subblocks in enumerate(blocks): + for j in range(num_subblocks): + next_block_index = i + 1 if j + 1 == num_subblocks and i + 1 < num_blocks else i + output_channels = embed_channels[next_block_index] + + stride_q = [1, 1, 1] + if pool_countdown == 0: + stride_q = pool_q_stride + pool_countdown = blocks[next_block_index] + + stride_kv = [max(s // stride_q[d], 1) for d, s in enumerate(stride_kv)] + + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) + + self.blocks.append( + block( + input_channels=input_channels, + output_channels=output_channels, + num_heads=heads[i], + kernel_q=pool_kvq_kernel, + kernel_kv=pool_kvq_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + ) + ) + input_channels = output_channels + stage_block_id += 1 + pool_countdown -= 1 + self.norm = norm_layer(output_channels) + + # Classifier module + layers: List[nn.Module] = [] + if dropout > 0.0: + layers.append(nn.Dropout(dropout, inplace=True)) + layers.append(nn.Linear(output_channels, num_classes)) + self.head = nn.Sequential(*layers) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.LayerNorm): + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, PositionalEncoding): + for weights in m.parameters(): + nn.init.trunc_normal_(weights, std=0.02) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0]) + x = self.conv_proj(x) + x = x.flatten(2).transpose(1, 2) + + # add positional encoding + x = self.pos_encoding(x) + + # pass patches through the encoder + thw = (self.pos_encoding.temporal_size,) + self.pos_encoding.spatial_size + for block in self.blocks: + x, thw = block(x, thw) + x = self.norm(x) + + # classifier "token" as used by standard language architectures + x = x[:, 0] + x = self.head(x) + + return x + + +def _mvitv2( + embed_channels: List[int], + blocks: List[int], + heads: List[int], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> MViTV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + assert weights.meta["min_size"][0] == weights.meta["min_size"][1] + _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"][0]) + _ovewrite_named_param(kwargs, "temporal_size", weights.meta["min_temporal_size"]) + spatial_size = kwargs.pop("spatial_size", (224, 224)) + temporal_size = kwargs.pop("temporal_size", 16) + + model = MViTV2( + spatial_size=spatial_size, + temporal_size=temporal_size, + embed_channels=embed_channels, + blocks=blocks, + heads=heads, + pool_kv_stride=kwargs.pop("pool_kv_stride", [1, 8, 8]), + pool_q_stride=kwargs.pop("pool_q_stride", [1, 2, 2]), + pool_kvq_kernel=kwargs.pop("pool_kvq_kernel", [3, 3, 3]), + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +class MViT_V2_T_Weights(WeightsEnum): + pass + + +class MViT_V2_S_Weights(WeightsEnum): + pass + + +class MViT_V2_B_Weights(WeightsEnum): + pass + + +def mvit_v2_t(*, weights: Optional[MViT_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: + """ + Constructs a tiny MViTV2 architecture from + `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection + `__ and `Multiscale Vision Transformers + `__. + + Args: + weights (:class:`~torchvision.models.video.MViT_V2_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViT_V2_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.MViT_V2_T_Weights + :members: + """ + weights = MViT_V2_T_Weights.verify(weights) + + return _mvitv2( + spatial_size=(224, 224), + temporal_size=16, + embed_channels=[96, 192, 384, 768], + blocks=[1, 2, 5, 2], + heads=[1, 2, 4, 8], + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1), + weights=weights, + progress=progress, + **kwargs, + ) + + +def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: + """ + Constructs a small MViTV2 architecture from + `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection + `__ and `Multiscale Vision Transformers + `__. + + Args: + weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViT_V2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.MViT_V2_S_Weights + :members: + """ + weights = MViT_V2_S_Weights.verify(weights) + + return _mvitv2( + spatial_size=(224, 224), + temporal_size=16, + embed_channels=[96, 192, 384, 768], + blocks=[1, 2, 11, 2], + heads=[1, 2, 4, 8], + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1), + weights=weights, + progress=progress, + **kwargs, + ) + + +def mvit_v2_b(*, weights: Optional[MViT_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: + """ + Constructs a base MViTV2 architecture from + `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection + `__ and `Multiscale Vision Transformers + `__. + + Args: + weights (:class:`~torchvision.models.video.MViT_V2_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViT_V2_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.MViT_V2_B_Weights + :members: + """ + weights = MViT_V2_B_Weights.verify(weights) + + return _mvitv2( + spatial_size=(224, 224), + temporal_size=32, + embed_channels=[96, 192, 384, 768], + blocks=[2, 3, 16, 3], + heads=[1, 2, 4, 8], + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.3), + weights=weights, + progress=progress, + **kwargs, + ) From 3e7683f8a462a9601b5744ef5c20493331c3c20e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 23 Jun 2022 15:10:52 +0100 Subject: [PATCH 2/7] Add support to MViT v1 (#6179) * Switch implementation to v1 variant. * Fix docs * Adding back a v2 pseudovariant * Changing the way the network are configured. * Temporarily removing v2 * Adding weights. * Expand _squeeze/_unsqueeze to support arbitrary dims. * Update references script. * Fix tests. * Fixing frames and preprocessing. * Fix std/mean values in transforms. * Add permanent Dropout and update the weights. * Update accuracies. --- docs/source/models.rst | 2 +- .../{video_mvitv2.rst => video_mvit.rst} | 10 +- references/video_classification/train.py | 5 +- ... => ModelTester.test_mvit_v1_b_expect.pkl} | Bin 939 -> 939 bytes .../ModelTester.test_mvit_v2_t_expect.pkl | Bin 939 -> 0 bytes test/test_extended_models.py | 4 +- test/test_models.py | 9 +- torchvision/models/video/__init__.py | 2 +- .../models/video/{mvitv2.py => mvit.py} | 494 ++++++++++-------- 9 files changed, 286 insertions(+), 240 deletions(-) rename docs/source/models/{video_mvitv2.rst => video_mvit.rst} (72%) rename test/expect/{ModelTester.test_mvit_v2_s_expect.pkl => ModelTester.test_mvit_v1_b_expect.pkl} (54%) delete mode 100644 test/expect/ModelTester.test_mvit_v2_t_expect.pkl rename torchvision/models/video/{mvitv2.py => mvit.py} (59%) diff --git a/docs/source/models.rst b/docs/source/models.rst index 0edb394f08f..769c2d2721b 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -465,7 +465,7 @@ pre-trained weights: .. toctree:: :maxdepth: 1 - models/video_mvitv2 + models/video_mvit models/video_resnet | diff --git a/docs/source/models/video_mvitv2.rst b/docs/source/models/video_mvit.rst similarity index 72% rename from docs/source/models/video_mvitv2.rst rename to docs/source/models/video_mvit.rst index e9ad556ded7..713ca769f0b 100644 --- a/docs/source/models/video_mvitv2.rst +++ b/docs/source/models/video_mvit.rst @@ -12,17 +12,15 @@ The MViT V2 model is based on the Model builders -------------- -The following model builders can be used to instantiate a MViTV2 model, with or +The following model builders can be used to instantiate a MViT model, with or without pre-trained weights. All the model builders internally rely on the -``torchvision.models.video.MViTV2`` base class. Please refer to the `source +``torchvision.models.video.MViT`` base class. Please refer to the `source code -`_ for +`_ for more details about this class. .. autosummary:: :toctree: generated/ :template: function.rst - mvit_v2_t - mvit_v2_s - mvit_v2_b + mvit_v1_b diff --git a/references/video_classification/train.py b/references/video_classification/train.py index e1df08cbe4a..a746470be9b 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -152,7 +152,7 @@ def main(args): split="train", step_between_clips=1, transform=transform_train, - frame_rate=15, + frame_rate=args.frame_rate, extensions=( "avi", "mp4", @@ -189,7 +189,7 @@ def main(args): split="val", step_between_clips=1, transform=transform_test, - frame_rate=15, + frame_rate=args.frame_rate, extensions=( "avi", "mp4", @@ -324,6 +324,7 @@ def parse_args(): parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip") + parser.add_argument("--frame-rate", default=15, type=int, metavar="N", help="the frame rate") parser.add_argument( "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider" ) diff --git a/test/expect/ModelTester.test_mvit_v2_s_expect.pkl b/test/expect/ModelTester.test_mvit_v1_b_expect.pkl similarity index 54% rename from test/expect/ModelTester.test_mvit_v2_s_expect.pkl rename to test/expect/ModelTester.test_mvit_v1_b_expect.pkl index 48c342eec2d765c8c941da776543cb2f12df5d4c..cc6592c97bd17a81fa0281fd0013a8a2479e6c47 100644 GIT binary patch delta 230 zcmVDrG-p6~(?ra|r>VZ+f}_4cH9kIj;Uzx6_K&=y*M3nyYmhlULMq667FF$6rHlLjV8( delta 230 zcmV zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK630<(z3r5iKXw5nDSNkg`|tOO6S5UQG;`mh6C3Qgwkp`asd~Gw?Gu~5$V%@0 z)1L3zJ0pDezJEt<+x@X$yN84Iq8&qpo^{*vpL=aOH`%R;+_rC~v4q`}y$9@UPA;%B zIF)JNakO_oLzaP++ea(=yx6<8M<=xIXIXh@FRRlcJ9br}{jyvRb`BqV_d0m%*oty8 z?AP3?X|FYdbAQF&{reA{{bTz)nA5IrVUm6E62|>)wUzc2_jm5=Ygu5|5M;YYBYK~G z(fSK^e{}cl1%=j*hC4bZfFT9KxI>Gd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp z0p4tEI#5M&%(`&ppu`LUFnT+L%P int: @@ -36,18 +43,18 @@ def _prod(s: Sequence[int]) -> int: return product -def _unsqueeze(x: torch.Tensor) -> Tuple[torch.Tensor, int]: +def _unsqueeze(x: torch.Tensor, target_dim: int, expand_dim: int) -> Tuple[torch.Tensor, int]: tensor_dim = x.dim() - if tensor_dim == 3: - x = x.unsqueeze(1) - elif tensor_dim != 4: + if tensor_dim == target_dim - 1: + x = x.unsqueeze(expand_dim) + elif tensor_dim != target_dim: raise ValueError(f"Unsupported input dimension {x.shape}") return x, tensor_dim -def _squeeze(x: torch.Tensor, tensor_dim: int) -> torch.Tensor: - if tensor_dim == 3: - x = x.squeeze(1) +def _squeeze(x: torch.Tensor, target_dim: int, expand_dim: int, tensor_dim: int) -> torch.Tensor: + if tensor_dim == target_dim - 1: + x = x.squeeze(expand_dim) return x @@ -74,7 +81,7 @@ def __init__( self.norm_before_pool = norm_before_pool def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: - x, tensor_dim = _unsqueeze(x) + x, tensor_dim = _unsqueeze(x, 4, 1) # Separate the class token and reshape the input class_token, x = torch.tensor_split(x, indices=(1,), dim=2) @@ -95,7 +102,7 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten if not self.norm_before_pool and self.norm_act is not None: x = self.norm_act(x) - x = _squeeze(x, tensor_dim) + x = _squeeze(x, 4, 1, tensor_dim) return x, (T, H, W) @@ -108,6 +115,7 @@ def __init__( kernel_kv: List[int], stride_q: List[int], stride_kv: List[int], + residual_pool: bool, dropout: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ) -> None: @@ -116,6 +124,7 @@ def __init__( self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scaler = 1.0 / math.sqrt(self.head_dim) + self.residual_pool = residual_pool self.qkv = nn.Linear(embed_dim, 3 * embed_dim) layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)] @@ -182,7 +191,9 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten attn = torch.matmul(self.scaler * q, k.transpose(2, 3)) attn = attn.softmax(dim=-1) - x = torch.matmul(attn, v).add_(q) + x = torch.matmul(attn, v) + if self.residual_pool: + x.add_(q) x = x.transpose(1, 2).reshape(B, -1, C) x = self.project(x) @@ -192,13 +203,8 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten class MultiscaleBlock(nn.Module): def __init__( self, - input_channels: int, - output_channels: int, - num_heads: int, - kernel_q: List[int], - kernel_kv: List[int], - stride_q: List[int], - stride_kv: List[int], + cnf: MSBlockConfig, + residual_pool: bool, dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, @@ -206,30 +212,31 @@ def __init__( super().__init__() self.pool_skip: Optional[nn.Module] = None - if _prod(stride_q) > 1: - kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + if _prod(cnf.stride_q) > 1: + kernel_skip = [s + 1 if s > 1 else s for s in cnf.stride_q] padding_skip = [int(k // 2) for k in kernel_skip] self.pool_skip = Pool( - nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None # type: ignore[arg-type] + nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip), None # type: ignore[arg-type] ) - self.norm1 = norm_layer(input_channels) - self.norm2 = norm_layer(input_channels) + self.norm1 = norm_layer(cnf.input_channels) + self.norm2 = norm_layer(cnf.input_channels) self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d) self.attn = MultiscaleAttention( - input_channels, - num_heads, - kernel_q=kernel_q, - kernel_kv=kernel_kv, - stride_q=stride_q, - stride_kv=stride_kv, + cnf.input_channels, + cnf.num_heads, + kernel_q=cnf.kernel_q, + kernel_kv=cnf.kernel_kv, + stride_q=cnf.stride_q, + stride_kv=cnf.stride_kv, + residual_pool=residual_pool, dropout=dropout, norm_layer=norm_layer, ) self.mlp = MLP( - input_channels, - [4 * input_channels, output_channels], + cnf.input_channels, + [4 * cnf.input_channels, cnf.output_channels], activation_layer=nn.GELU, dropout=dropout, inplace=None, @@ -238,8 +245,8 @@ def __init__( self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.project: Optional[nn.Module] = None - if input_channels != output_channels: - self.project = nn.Linear(input_channels, output_channels) + if cnf.input_channels != cnf.output_channels: + self.project = nn.Linear(cnf.input_channels, cnf.output_channels) def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] @@ -274,18 +281,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat((class_token, x), dim=1).add_(pos_embedding) -class MViTV2(nn.Module): +class MViT(nn.Module): def __init__( self, spatial_size: Tuple[int, int], temporal_size: int, - embed_channels: List[int], - blocks: List[int], - heads: List[int], - pool_kv_stride: List[int], - pool_q_stride: List[int], - pool_kvq_kernel: List[int], - dropout: float = 0.0, + block_setting: Sequence[MSBlockConfig], + residual_pool: bool, + dropout: float = 0.5, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, num_classes: int = 400, @@ -293,17 +296,13 @@ def __init__( norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: """ - MViT V2 main class. + MViT main class. Args: spatial_size (tuple of ints): The spacial size of the input as ``(H, W)``. temporal_size (int): The temporal size ``T`` of the input. - embed_channels (list of ints): A list with the embedding dimensions of each block group. - blocks (list of ints): A list with the number of blocks of each block group. - heads (list of ints): A list with the number of heads of each block group. - pool_kv_stride (list of ints): The initiale pooling stride of the first block. - pool_q_stride (list of ints): The pooling stride which reduces q in each block group. - pool_kvq_kernel (list of ints): The pooling kernel for the attention. + block_setting (sequence of MSBlockConfig): The Network structure. + residual_pool (bool): If True, use MViTv2 pooling residual connection. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. @@ -317,9 +316,9 @@ def __init__( # We remove any experimental configuration that didn't make it to the final variants of the models. To represent # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper. _log_api_usage_once(self) - num_blocks = len(blocks) - if num_blocks != len(embed_channels) or num_blocks != len(heads): - raise ValueError("The parameters 'embed_channels', 'blocks' and 'heads' must have equal length.") + total_stage_blocks = len(block_setting) + if total_stage_blocks == 0: + raise ValueError("The configuration parameter can't be empty.") if block is None: block = MultiscaleBlock @@ -330,7 +329,7 @@ def __init__( # Patch Embedding module self.conv_proj = nn.Conv3d( in_channels=3, - out_channels=embed_channels[0], + out_channels=block_setting[0].input_channels, kernel_size=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3), @@ -338,58 +337,33 @@ def __init__( # Spatio-Temporal Class Positional Encoding self.pos_encoding = PositionalEncoding( - embed_size=embed_channels[0], + embed_size=block_setting[0].input_channels, spatial_size=(spatial_size[0] // self.conv_proj.stride[1], spatial_size[1] // self.conv_proj.stride[2]), temporal_size=temporal_size // self.conv_proj.stride[0], ) # Encoder module self.blocks = nn.ModuleList() - stage_block_id = 0 - pool_countdown = blocks[0] - input_channels = output_channels = embed_channels[0] - stride_kv = pool_kv_stride - total_stage_blocks = sum(blocks) - for i, num_subblocks in enumerate(blocks): - for j in range(num_subblocks): - next_block_index = i + 1 if j + 1 == num_subblocks and i + 1 < num_blocks else i - output_channels = embed_channels[next_block_index] - - stride_q = [1, 1, 1] - if pool_countdown == 0: - stride_q = pool_q_stride - pool_countdown = blocks[next_block_index] - - stride_kv = [max(s // stride_q[d], 1) for d, s in enumerate(stride_kv)] - - # adjust stochastic depth probability based on the depth of the stage block - sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) - - self.blocks.append( - block( - input_channels=input_channels, - output_channels=output_channels, - num_heads=heads[i], - kernel_q=pool_kvq_kernel, - kernel_kv=pool_kvq_kernel, - stride_q=stride_q, - stride_kv=stride_kv, - dropout=attention_dropout, - stochastic_depth_prob=sd_prob, - norm_layer=norm_layer, - ) + for stage_block_id, cnf in enumerate(block_setting): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) + + self.blocks.append( + block( + cnf=cnf, + residual_pool=residual_pool, + dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, ) - input_channels = output_channels - stage_block_id += 1 - pool_countdown -= 1 - self.norm = norm_layer(output_channels) + ) + self.norm = norm_layer(block_setting[-1].output_channels) # Classifier module - layers: List[nn.Module] = [] - if dropout > 0.0: - layers.append(nn.Dropout(dropout, inplace=True)) - layers.append(nn.Linear(output_channels, num_classes)) - self.head = nn.Sequential(*layers) + self.head = nn.Sequential( + nn.Dropout(dropout, inplace=True), + nn.Linear(block_setting[-1].output_channels, num_classes), + ) for m in self.modules(): if isinstance(m, nn.Linear): @@ -426,32 +400,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def _mvitv2( - embed_channels: List[int], - blocks: List[int], - heads: List[int], +def _mvit( + block_setting: List[MSBlockConfig], stochastic_depth_prob: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, -) -> MViTV2: +) -> MViT: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) assert weights.meta["min_size"][0] == weights.meta["min_size"][1] - _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"][0]) + _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"]) _ovewrite_named_param(kwargs, "temporal_size", weights.meta["min_temporal_size"]) spatial_size = kwargs.pop("spatial_size", (224, 224)) temporal_size = kwargs.pop("temporal_size", 16) - model = MViTV2( + model = MViT( spatial_size=spatial_size, temporal_size=temporal_size, - embed_channels=embed_channels, - blocks=blocks, - heads=heads, - pool_kv_stride=kwargs.pop("pool_kv_stride", [1, 8, 8]), - pool_q_stride=kwargs.pop("pool_q_stride", [1, 2, 2]), - pool_kvq_kernel=kwargs.pop("pool_kvq_kernel", [3, 3, 3]), + block_setting=block_setting, + residual_pool=kwargs.pop("residual_pool", False), stochastic_depth_prob=stochastic_depth_prob, **kwargs, ) @@ -462,126 +430,210 @@ def _mvitv2( return model -class MViT_V2_T_Weights(WeightsEnum): - pass - - -class MViT_V2_S_Weights(WeightsEnum): - pass - - -class MViT_V2_B_Weights(WeightsEnum): - pass - - -def mvit_v2_t(*, weights: Optional[MViT_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: - """ - Constructs a tiny MViTV2 architecture from - `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection - `__ and `Multiscale Vision Transformers - `__. - - Args: - weights (:class:`~torchvision.models.video.MViT_V2_T_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.models.video.MViT_V2_T_Weights` below for - more details, and possible values. By default, no pre-trained - weights are used. - progress (bool, optional): If True, displays a progress bar of the - download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` - base class. Please refer to the `source code - `_ - for more details about this class. - - .. autoclass:: torchvision.models.video.MViT_V2_T_Weights - :members: - """ - weights = MViT_V2_T_Weights.verify(weights) - - return _mvitv2( - spatial_size=(224, 224), - temporal_size=16, - embed_channels=[96, 192, 384, 768], - blocks=[1, 2, 5, 2], - heads=[1, 2, 4, 8], - stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1), - weights=weights, - progress=progress, - **kwargs, +class MViT_V1_B_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mvit_v1_b-dbeb1030.pth", + transforms=partial( + VideoClassification, + crop_size=(224, 224), + resize_size=(256,), + mean=(0.45, 0.45, 0.45), + std=(0.225, 0.225, 0.225), + ), + meta={ + "min_size": (224, 224), + "min_temporal_size": 16, + "categories": _KINETICS400_CATEGORIES, + "recipe": "https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md", + "_docs": """These weights support 16-frame clip inputs and were ported from the paper.""", + "num_params": 36610672, + "_metrics": { + "Kinetics-400": { + "acc@1": 78.47, + "acc@5": 93.65, + } + }, + }, ) + DEFAULT = KINETICS400_V1 -def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: +def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: """ - Constructs a small MViTV2 architecture from - `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection - `__ and `Multiscale Vision Transformers - `__. + Constructs a base MViTV1 architecture from + `Multiscale Vision Transformers `__. Args: - weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The + weights (:class:`~torchvision.models.video.MViT_V1_B_Weights`, optional): The pretrained weights to use. See - :class:`~torchvision.models.video.MViT_V2_S_Weights` below for + :class:`~torchvision.models.video.MViT_V1_B_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` + **kwargs: parameters passed to the ``torchvision.models.video.MViT`` base class. Please refer to the `source code - `_ + `_ for more details about this class. - .. autoclass:: torchvision.models.video.MViT_V2_S_Weights + .. autoclass:: torchvision.models.video.MViT_V1_B_Weights :members: """ - weights = MViT_V2_S_Weights.verify(weights) - - return _mvitv2( + weights = MViT_V1_B_Weights.verify(weights) + + block_setting = [ + MSBlockConfig( + num_heads=1, + input_channels=96, + output_channels=192, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 8, 8], + ), + MSBlockConfig( + num_heads=2, + input_channels=192, + output_channels=192, + kernel_q=[3, 3, 3], + kernel_kv=[3, 3, 3], + stride_q=[1, 2, 2], + stride_kv=[1, 4, 4], + ), + MSBlockConfig( + num_heads=2, + input_channels=192, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 4, 4], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[3, 3, 3], + kernel_kv=[3, 3, 3], + stride_q=[1, 2, 2], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=384, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=4, + input_channels=384, + output_channels=768, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 2, 2], + ), + MSBlockConfig( + num_heads=8, + input_channels=768, + output_channels=768, + kernel_q=[3, 3, 3], + kernel_kv=[3, 3, 3], + stride_q=[1, 2, 2], + stride_kv=[1, 1, 1], + ), + MSBlockConfig( + num_heads=8, + input_channels=768, + output_channels=768, + kernel_q=[], + kernel_kv=[3, 3, 3], + stride_q=[], + stride_kv=[1, 1, 1], + ), + ] + + return _mvit( spatial_size=(224, 224), temporal_size=16, - embed_channels=[96, 192, 384, 768], - blocks=[1, 2, 11, 2], - heads=[1, 2, 4, 8], - stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1), - weights=weights, - progress=progress, - **kwargs, - ) - - -def mvit_v2_b(*, weights: Optional[MViT_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: - """ - Constructs a base MViTV2 architecture from - `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection - `__ and `Multiscale Vision Transformers - `__. - - Args: - weights (:class:`~torchvision.models.video.MViT_V2_B_Weights`, optional): The - pretrained weights to use. See - :class:`~torchvision.models.video.MViT_V2_B_Weights` below for - more details, and possible values. By default, no pre-trained - weights are used. - progress (bool, optional): If True, displays a progress bar of the - download to stderr. Default is True. - **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` - base class. Please refer to the `source code - `_ - for more details about this class. - - .. autoclass:: torchvision.models.video.MViT_V2_B_Weights - :members: - """ - weights = MViT_V2_B_Weights.verify(weights) - - return _mvitv2( - spatial_size=(224, 224), - temporal_size=32, - embed_channels=[96, 192, 384, 768], - blocks=[2, 3, 16, 3], - heads=[1, 2, 4, 8], - stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.3), + block_setting=block_setting, + residual_pool=False, + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.2), weights=weights, progress=progress, **kwargs, From 6a9edef290576bebdb83a7e73b24858e92b9addb Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 23 Jun 2022 15:39:04 +0100 Subject: [PATCH 3/7] Fix documentation --- docs/source/models/video_mvit.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/models/video_mvit.rst b/docs/source/models/video_mvit.rst index 713ca769f0b..d5be1245ac9 100644 --- a/docs/source/models/video_mvit.rst +++ b/docs/source/models/video_mvit.rst @@ -1,9 +1,9 @@ -Video ResNet -============ +Video MViT +========== .. currentmodule:: torchvision.models.video -The MViT V2 model is based on the +The MViT model is based on the `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection `__ and `Multiscale Vision Transformers `__ papers. From 23a95a7b1e69f3661794aa8f68027cc5c2ad5cf9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 23 Jun 2022 15:41:34 +0100 Subject: [PATCH 4/7] Remove unnecessary expected file. --- test/expect/ModelTester.test_mvit_v2_b_expect.pkl | Bin 939 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/expect/ModelTester.test_mvit_v2_b_expect.pkl diff --git a/test/expect/ModelTester.test_mvit_v2_b_expect.pkl b/test/expect/ModelTester.test_mvit_v2_b_expect.pkl deleted file mode 100644 index a39dd2e5754797feaba9c8e10f103900e77a6c98..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5~q{x-o9I|2lkzD+hTX@zM=hbeMS2krpWykA*=1f#ZC6Fv3P1Xt$D^ihs=E2 zW`_-Sf{`oj42s-r0v0Ijy_7R`-^Svr`@%L}-M>0=;r_f|{`;4hEV36gkliQjX|(^- z{}j9TRebi-zGm&4aZJqa|4UW-CAa$a1xp;-7hbQl{{VOXo&qC>yGd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp z0p4tEI#5M&%(`&ppu`LUFnT+L%P Date: Thu, 23 Jun 2022 17:08:16 +0100 Subject: [PATCH 5/7] Skip big model test --- test/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_models.py b/test/test_models.py index 63e4870801f..1f4973338ad 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -833,7 +833,7 @@ def test_video_model(model_fn, dev): "num_classes": 50, } model_name = model_fn.__name__ - if dev == "cuda" and SKIP_BIG_MODEL and model_name in skipped_big_models: + if SKIP_BIG_MODEL and model_name in skipped_big_models: pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model") kwargs = {**defaults, **_model_params.get(model_name, {})} num_classes = kwargs.get("num_classes") From 5da567a62e5d0da024d900032c5ebcda834c3327 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 24 Jun 2022 13:39:20 +0100 Subject: [PATCH 6/7] Rewrite the configuration logic to reduce LOC. --- torchvision/models/video/mvit.py | 203 +++++++++---------------------- 1 file changed, 57 insertions(+), 146 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index a3b543cf8ae..36855d21ab7 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -481,152 +481,63 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T """ weights = MViT_V1_B_Weights.verify(weights) - block_setting = [ - MSBlockConfig( - num_heads=1, - input_channels=96, - output_channels=192, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 8, 8], - ), - MSBlockConfig( - num_heads=2, - input_channels=192, - output_channels=192, - kernel_q=[3, 3, 3], - kernel_kv=[3, 3, 3], - stride_q=[1, 2, 2], - stride_kv=[1, 4, 4], - ), - MSBlockConfig( - num_heads=2, - input_channels=192, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 4, 4], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[3, 3, 3], - kernel_kv=[3, 3, 3], - stride_q=[1, 2, 2], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=384, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=4, - input_channels=384, - output_channels=768, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 2, 2], - ), - MSBlockConfig( - num_heads=8, - input_channels=768, - output_channels=768, - kernel_q=[3, 3, 3], - kernel_kv=[3, 3, 3], - stride_q=[1, 2, 2], - stride_kv=[1, 1, 1], - ), - MSBlockConfig( - num_heads=8, - input_channels=768, - output_channels=768, - kernel_q=[], - kernel_kv=[3, 3, 3], - stride_q=[], - stride_kv=[1, 1, 1], - ), - ] + config = { + "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8], + "input_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768], + "output_channels": [192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768, 768], + "kernel_q": [[], [3, 3, 3], [], [3, 3, 3], [], [], [], [], [], [], [], [], [], [], [3, 3, 3], []], + "kernel_kv": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + ], + "stride_q": [[], [1, 2, 2], [], [1, 2, 2], [], [], [], [], [], [], [], [], [], [], [1, 2, 2], []], + "stride_kv": [ + [1, 8, 8], + [1, 4, 4], + [1, 4, 4], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 2, 2], + [1, 1, 1], + [1, 1, 1], + ], + } + + block_setting = [] + for i in range(len(config["num_heads"])): + block_setting.append( + MSBlockConfig( + num_heads=config["num_heads"][i], + input_channels=config["input_channels"][i], + output_channels=config["output_channels"][i], + kernel_q=config["kernel_q"][i], + kernel_kv=config["kernel_kv"][i], + stride_q=config["stride_q"][i], + stride_kv=config["stride_kv"][i], + ) + ) return _mvit( spatial_size=(224, 224), From c358a9707e8141174db91b58d122eb99e98e9169 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 24 Jun 2022 13:48:50 +0100 Subject: [PATCH 7/7] Fix mypy --- torchvision/models/video/mvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 36855d21ab7..d8bfc0dbb77 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -1,7 +1,7 @@ import math from dataclasses import dataclass from functools import partial -from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import torch import torch.fx @@ -481,7 +481,7 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T """ weights = MViT_V1_B_Weights.verify(weights) - config = { + config: Dict[str, List] = { "num_heads": [1, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 8], "input_channels": [96, 192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768], "output_channels": [192, 192, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 768, 768, 768],