From 2d3a0afa7a61d980e772d05062d30045d13f0d88 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 27 May 2022 10:20:55 +0100 Subject: [PATCH 01/11] Adding mvitv2 architecture --- .../ModelTester.test_mvitv2_b_expect.pkl | Bin 0 -> 939 bytes .../ModelTester.test_mvitv2_s_expect.pkl | Bin 0 -> 939 bytes .../ModelTester.test_mvitv2_t_expect.pkl | Bin 0 -> 939 bytes test/test_models.py | 12 + torchvision/models/video/__init__.py | 1 + torchvision/models/video/mvit.py | 485 ++++++++++++++++++ 6 files changed, 498 insertions(+) create mode 100644 test/expect/ModelTester.test_mvitv2_b_expect.pkl create mode 100644 test/expect/ModelTester.test_mvitv2_s_expect.pkl create mode 100644 test/expect/ModelTester.test_mvitv2_t_expect.pkl create mode 100644 torchvision/models/video/mvit.py diff --git a/test/expect/ModelTester.test_mvitv2_b_expect.pkl b/test/expect/ModelTester.test_mvitv2_b_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a39dd2e5754797feaba9c8e10f103900e77a6c98 GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5~q{x-o9I|2lkzD+hTX@zM=hbeMS2krpWykA*=1f#ZC6Fv3P1Xt$D^ihs=E2 zW`_-Sf{`oj42s-r0v0Ijy_7R`-^Svr`@%L}-M>0=;r_f|{`;4hEV36gkliQjX|(^- z{}j9TRebi-zGm&4aZJqa|4UW-CAa$a1xp;-7hbQl{{VOXo&qC>yGd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp z0p4tEI#5M&%(`&ppu`LUFnT+L%P zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=U>o(!Q?udi&G=JhWAO>uf(|x4C_RL5lqwr#M^7!iD>e*~i%}|Lr5t{wZNh`=4ETzVD5R;(ocL$#&@lHxrBfrgKjF7X$>`FDPo*)9`=wZV%Q4doC~@ z-FqZ*mE9(7r~S2w#`dv4U)$9rKi*sL_03)z7r%XC{}=CV^XJ#89K2T`Q(|Et_1TdsP7;2j35f0CXwS%03?9|&{HV7Ze&04q3C=C?(c~y%Ind!t_GJ zAi$fAO$Vw-j#(G39F&+r07h?za2Y0nJqhwI8z^ructRC`GC_bhD;r3R83;k@A!-5R CO8RjC literal 0 HcmV?d00001 diff --git a/test/expect/ModelTester.test_mvitv2_t_expect.pkl b/test/expect/ModelTester.test_mvitv2_t_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..384fe05b50c9fc4fcc92c39f62d6dc299282ea9f GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK630<(z3r5iKXw5nDSNkg`|tOO6S5UQG;`mh6C3Qgwkp`asd~Gw?Gu~5$V%@0 z)1L3zJ0pDezJEt<+x@X$yN84Iq8&qpo^{*vpL=aOH`%R;+_rC~v4q`}y$9@UPA;%B zIF)JNakO_oLzaP++ea(=yx6<8M<=xIXIXh@FRRlcJ9br}{jyvRb`BqV_d0m%*oty8 z?AP3?X|FYdbAQF&{reA{{bTz)nA5IrVUm6E62|>)wUzc2_jm5=Ygu5|5M;YYBYK~G z(fSK^e{}cl1%=j*hC4bZfFT9KxI>Gd!5SV~WvNBQz*ul|GAA;)kU|c^H0A=?d~sfS zC=<|D5DxHW1X1ubi5!OlAPE$JooXBrWYCp z0p4tEI#5M&%(`&ppu`LUFnT+L%P int: + product = 1 + for v in s: + product *= v + return product + + +def _unsqueeze(x: torch.Tensor) -> Tuple[torch.Tensor, int]: + tensor_dim = x.dim() + if tensor_dim == 3: + x = x.unsqueeze(1) + elif tensor_dim != 4: + raise NotImplementedError(f"Unsupported input dimension {x.shape}") + return x, tensor_dim + + +def _squeeze(x: torch.Tensor, tensor_dim: int) -> torch.Tensor: + if tensor_dim == 3: + x = x.squeeze(1) + return x + + +torch.fx.wrap("_unsqueeze") +torch.fx.wrap("_squeeze") + + +class Pool(nn.Module): + def __init__( + self, + pool: nn.Module, + norm: Optional[nn.Module], + activation: Optional[nn.Module] = None, + norm_before_pool: bool = False, + ) -> None: + super().__init__() + self.pool = pool + layers = [] + if norm is not None: + 
layers.append(norm) + if activation is not None: + layers.append(activation) + self.norm_act = nn.Sequential(*layers) if layers else None + self.norm_before_pool = norm_before_pool + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + x, tensor_dim = _unsqueeze(x) + + # Separate the class token and reshape the input + class_token, x = torch.tensor_split(x, indices=(1,), dim=2) + x = x.transpose(2, 3) + B, N, C = x.shape[:3] + x = x.reshape((B * N, C) + thw) + + # normalizing prior pooling is useful when we use BN which can be absorbed to speed up inference + if self.norm_before_pool and self.norm_act is not None: + x = self.norm_act(x) + + # apply the pool on the input and add back the token + x = self.pool(x) + T, H, W = x.shape[2:] + x = x.reshape(B, N, C, -1).transpose(2, 3) + x = torch.cat((class_token, x), dim=2) + + if not self.norm_before_pool and self.norm_act is not None: + x = self.norm_act(x) + + x = _squeeze(x, tensor_dim) + return x, (T, H, W) + + +class MultiscaleAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + kernel_q: Tuple[int, int, int] = (1, 1, 1), + kernel_kv: Tuple[int, int, int] = (1, 1, 1), + stride_q: Tuple[int, int, int] = (1, 1, 1), + stride_kv: Tuple[int, int, int] = (1, 1, 1), + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scaler = 1.0 / math.sqrt(self.head_dim) + + self.qkv = nn.Linear(embed_dim, 3 * embed_dim) + layers: List[nn.Module] = [nn.Linear(embed_dim, embed_dim)] + if dropout > 0.0: + layers.append(nn.Dropout(dropout, inplace=True)) + self.project = nn.Sequential(*layers) + + self.pool_q: Optional[nn.Module] = None + if _prod(kernel_q) > 1 or _prod(stride_q) > 1: + padding_q = cast(Tuple[int, int, int], tuple(int(q // 2) for q in kernel_q)) + self.pool_q = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_q, + stride=stride_q, + padding=padding_q, + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + + self.pool_k: Optional[nn.Module] = None + self.pool_v: Optional[nn.Module] = None + if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1: + padding_kv = cast(Tuple[int, int, int], tuple(int(kv // 2) for kv in kernel_kv)) + self.pool_k = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + self.pool_v = Pool( + nn.Conv3d( + self.head_dim, + self.head_dim, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=self.head_dim, + bias=False, + ), + norm_layer(self.head_dim), + ) + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + B, N, C = x.shape + q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(dim=2) + + if self.pool_k is not None: + k = self.pool_k(k, thw)[0] + if self.pool_v is not None: + v = self.pool_v(v, thw)[0] + if self.pool_q is not None: + q, thw = self.pool_q(q, thw) + + attn = torch.matmul(self.scaler * q, k.transpose(2, 3)) + attn = attn.softmax(dim=-1) + + x = torch.matmul(attn, v).add_(q) + x = x.transpose(1, 2).reshape(B, -1, C) + x = self.project(x) + + return x, thw + + +class MultiscaleBlock(nn.Module): + def __init__( + self, + input_channels: int, + output_channels: int, + num_heads: int, + 
dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + kernel_q: Tuple[int, int, int] = (1, 1, 1), + kernel_kv: Tuple[int, int, int] = (1, 1, 1), + stride_q: Tuple[int, int, int] = (1, 1, 1), + stride_kv: Tuple[int, int, int] = (1, 1, 1), + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(input_channels) + self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d) + self.attn = MultiscaleAttention( + input_channels, + num_heads, + dropout=dropout, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=norm_layer, + ) + self.norm2 = norm_layer(input_channels) + self.mlp = MLP( + input_channels, + [4 * input_channels, output_channels], + activation_layer=nn.GELU, + dropout=dropout, + inplace=None, + ) + + self.project: Optional[nn.Module] = None + if input_channels != output_channels: + self.project = nn.Linear(input_channels, output_channels) + + self.pool_skip: Optional[nn.Module] = None + if _prod(stride_q) > 1: + kernel_skip = cast(Tuple[int, int, int], tuple(s + 1 if s > 1 else s for s in stride_q)) + padding_skip = cast(Tuple[int, int, int], tuple(int(k // 2) for k in kernel_skip)) + self.pool_skip = Pool(nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None) + + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + + def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: + x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] + + x = self.norm1(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm1(x) + x, thw = self.attn(x, thw) + x = x_skip + self.stochastic_depth(x) + + x_norm = self.norm2(x.transpose(1, 2)).transpose(1, 2) if self.needs_transposal else self.norm2(x) + x_proj = x if self.project is None else self.project(x_norm) + + return x_proj + self.stochastic_depth(self.mlp(x_norm)), thw + + +class PositionalEncoding(nn.Module): + def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int) -> None: + super().__init__() + self.spatial_size = spatial_size + self.temporal_size = temporal_size + + self.class_token = nn.Parameter(torch.zeros(embed_size)) + self.spatial_pos = nn.Parameter(torch.zeros(self.spatial_size[0] * self.spatial_size[1], embed_size)) + self.temporal_pos = nn.Parameter(torch.zeros(self.temporal_size, embed_size)) + self.class_pos = nn.Parameter(torch.zeros(embed_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hw_size, embed_size = self.spatial_pos.shape + pos_embedding = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0) + pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size)) + pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0) + class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1) + return torch.cat((class_token, x), dim=1) + pos_embedding + + +class MultiscaleVisionTransformer(nn.Module): + def __init__( + self, + spatial_size: Tuple[int, int], + temporal_size: int, + embed_channels: List[int], + blocks: List[int], + heads: List[int], + pool_kv_stride: Tuple[int, int, int] = (1, 8, 8), + pool_q_stride: Tuple[int, int, int] = (1, 2, 2), + pool_kvq_kernel: Tuple[int, int, int] = (3, 3, 3), + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + num_classes: int = 400, + block: Optional[Callable[..., nn.Module]] = None, 
+ norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + + if block is None: + block = MultiscaleBlock + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + # Patch Embedding module + self.conv_proj = nn.Conv3d( + in_channels=3, + out_channels=embed_channels[0], + kernel_size=(3, 7, 7), + stride=(2, 4, 4), + padding=(1, 3, 3), + ) + + # Spatio-Temporal Class Positional Encoding + self.pos_encoding = PositionalEncoding( + embed_size=embed_channels[0], + spatial_size=(spatial_size[0] // self.conv_proj.stride[1], spatial_size[1] // self.conv_proj.stride[2]), + temporal_size=temporal_size // self.conv_proj.stride[0], + ) + + # Encoder module + self.blocks = nn.ModuleList() + stage_block_id = 0 + pool_countdown = blocks[0] + input_channels = output_channels = embed_channels[0] + stride_kv = pool_kv_stride + total_stage_blocks = sum(blocks) + for i, num_subblocks in enumerate(blocks): + for j in range(num_subblocks): + next_block_index = i + 1 if j + 1 == num_subblocks and i + 1 < len(embed_channels) else i + output_channels = embed_channels[next_block_index] + + stride_q = (1, 1, 1) + if pool_countdown == 0: + stride_q = pool_q_stride + pool_countdown = blocks[next_block_index] + + stride_kv = cast(Tuple[int, int, int], tuple(max(s // stride_q[d], 1) for d, s in enumerate(stride_kv))) + + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) + + self.blocks.append( + block( + input_channels=input_channels, + output_channels=output_channels, + num_heads=heads[i], + dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + kernel_q=pool_kvq_kernel, + kernel_kv=pool_kvq_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=norm_layer, + ) + ) + input_channels = output_channels + stage_block_id += 1 + pool_countdown -= 1 + self.norm = norm_layer(output_channels) + + # Classifier module + layers: List[nn.Module] = [] + if dropout > 0.0: + layers.append(nn.Dropout(dropout, inplace=True)) + layers.append(nn.Linear(output_channels, num_classes)) + self.head = nn.Sequential(*layers) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, PositionalEncoding): + for weights in m.parameters(): + nn.init.trunc_normal_(weights, std=0.02) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # patchify and reshape: (B, C, T, H, W) -> (B, patch_embed_dim, T', H', W') -> (B, THW', patch_embed_dim) + x = self.conv_proj(x) + x = x.flatten(2).transpose(1, 2) + + # add positional encoding + x = self.pos_encoding(x) + + # pass patches through the encoder + thw = (self.pos_encoding.temporal_size,) + self.pos_encoding.spatial_size + for block in self.blocks: + x, thw = block(x, thw) + x = self.norm(x) + + # classifier "token" as used by standard language architectures + x = x[:, 0] + x = self.head(x) + + return x + + +def _mvitv2( + temporal_size: int, + embed_channels: List[int], + blocks: List[int], + heads: List[int], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> MultiscaleVisionTransformer: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", 
len(weights.meta["categories"])) + assert weights.meta["min_size"][0] == weights.meta["min_size"][1] + _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"][0]) + # TODO: add min_temporal_size in the meta-data? + spatial_size = kwargs.pop("spatial_size", (224, 224)) + + model = MultiscaleVisionTransformer( + spatial_size=spatial_size, + temporal_size=temporal_size, + embed_channels=embed_channels, + blocks=blocks, + heads=heads, + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +class MViTV2_T_Weights(WeightsEnum): + pass + + +class MViTV2_S_Weights(WeightsEnum): + pass + + +class MViTV2_B_Weights(WeightsEnum): + pass + + +def mvitv2_t( + *, weights: Optional[MViTV2_T_Weights] = None, progress: bool = True, **kwargs: Any +) -> MultiscaleVisionTransformer: + weights = MViTV2_T_Weights.verify(weights) + + return _mvitv2( + spatial_size=(224, 224), + temporal_size=16, + embed_channels=[96, 192, 384, 768], + blocks=[1, 2, 5, 2], + heads=[1, 2, 4, 8], + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1), + weights=weights, + progress=progress, + **kwargs, + ) + + +def mvitv2_s( + *, weights: Optional[MViTV2_S_Weights] = None, progress: bool = True, **kwargs: Any +) -> MultiscaleVisionTransformer: + weights = MViTV2_S_Weights.verify(weights) + + return _mvitv2( + spatial_size=(224, 224), + temporal_size=16, + embed_channels=[96, 192, 384, 768], + blocks=[1, 2, 11, 2], + heads=[1, 2, 4, 8], + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.1), + weights=weights, + progress=progress, + **kwargs, + ) + + +def mvitv2_b( + *, weights: Optional[MViTV2_B_Weights] = None, progress: bool = True, **kwargs: Any +) -> MultiscaleVisionTransformer: + weights = MViTV2_B_Weights.verify(weights) + + return _mvitv2( + spatial_size=(224, 224), + temporal_size=32, + embed_channels=[96, 192, 384, 768], + blocks=[2, 3, 16, 3], + heads=[1, 2, 4, 8], + stochastic_depth_prob=kwargs.pop("stochastic_depth_prob", 0.3), + weights=weights, + progress=progress, + **kwargs, + ) From ec8339f32407df2e9817d14665db979b3f1d60ce Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 27 May 2022 19:15:27 +0100 Subject: [PATCH 02/11] Fixing memory issues on tests and minor refactorings. --- test/test_models.py | 6 +++--- torchvision/models/video/mvit.py | 23 ++++++++++++----------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index e8f13f29bc3..6e586049008 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -318,9 +318,6 @@ def _check_input_backprop(model, inputs): "mvitv2_b": { "input_shape": (1, 3, 32, 224, 224), }, - "mvitv2_l": { - "input_shape": (1, 3, 40, 312, 312), - }, } # speeding up slow models: slow_models = [ @@ -350,6 +347,7 @@ def _check_input_backprop(model, inputs): skipped_big_models = { "vit_h_14", "regnet_y_128gf", + "mvitv2_b", } # The following contains configuration and expected values to be used tests that are model specific @@ -842,6 +840,8 @@ def test_video_model(model_fn, dev): "num_classes": 50, } model_name = model_fn.__name__ + if dev == "cuda" and SKIP_BIG_MODEL and model_name in skipped_big_models: + pytest.skip("Skipped to reduce memory usage. 
Set env var SKIP_BIG_MODEL=0 to enable test for this model") kwargs = {**defaults, **_model_params.get(model_name, {})} num_classes = kwargs.get("num_classes") input_shape = kwargs.pop("input_shape") diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index ca633e152bf..96aff3e6a59 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -195,8 +195,16 @@ def __init__( ) -> None: super().__init__() + self.pool_skip: Optional[nn.Module] = None + if _prod(stride_q) > 1: + kernel_skip = cast(Tuple[int, int, int], tuple(s + 1 if s > 1 else s for s in stride_q)) + padding_skip = cast(Tuple[int, int, int], tuple(int(k // 2) for k in kernel_skip)) + self.pool_skip = Pool(nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None) + self.norm1 = norm_layer(input_channels) + self.norm2 = norm_layer(input_channels) self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d) + self.attn = MultiscaleAttention( input_channels, num_heads, @@ -207,7 +215,6 @@ def __init__( stride_kv=stride_kv, norm_layer=norm_layer, ) - self.norm2 = norm_layer(input_channels) self.mlp = MLP( input_channels, [4 * input_channels, output_channels], @@ -216,18 +223,12 @@ def __init__( inplace=None, ) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + self.project: Optional[nn.Module] = None if input_channels != output_channels: self.project = nn.Linear(input_channels, output_channels) - self.pool_skip: Optional[nn.Module] = None - if _prod(stride_q) > 1: - kernel_skip = cast(Tuple[int, int, int], tuple(s + 1 if s > 1 else s for s in stride_q)) - padding_skip = cast(Tuple[int, int, int], tuple(int(k // 2) for k in kernel_skip)) - self.pool_skip = Pool(nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None) - - self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") - def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0] @@ -355,12 +356,12 @@ def __init__( if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) + nn.init.constant_(m.bias, 0.0) elif isinstance(m, nn.LayerNorm): if m.weight is not None: nn.init.constant_(m.weight, 1.0) if m.bias is not None: - nn.init.constant_(m.bias, 0) + nn.init.constant_(m.bias, 0.0) elif isinstance(m, PositionalEncoding): for weights in m.parameters(): nn.init.trunc_normal_(weights, std=0.02) From 2699b52534c36b0d4b7281ce387f53097035961a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 28 May 2022 09:19:38 +0100 Subject: [PATCH 03/11] Adding input validation --- torchvision/models/video/mvit.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvit.py index 96aff3e6a59..df73e0d1a10 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvit.py @@ -31,7 +31,7 @@ def _unsqueeze(x: torch.Tensor) -> Tuple[torch.Tensor, int]: if tensor_dim == 3: x = x.unsqueeze(1) elif tensor_dim != 4: - raise NotImplementedError(f"Unsupported input dimension {x.shape}") + raise ValueError(f"Unsupported input dimension {x.shape}") return x, tensor_dim @@ -277,10 +277,13 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, num_classes: int = 400, - block: Optional[Callable[..., nn.Module]] = None, + block: Callable[..., 
nn.Module] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super().__init__() + num_blocks = len(blocks) + if num_blocks != len(embed_channels) or num_blocks != len(heads): + raise ValueError("The parameters 'embed_channels', 'blocks' and 'heads' must have equal length.") if block is None: block = MultiscaleBlock @@ -304,7 +307,7 @@ def __init__( temporal_size=temporal_size // self.conv_proj.stride[0], ) - # Encoder module + # Encoder modules self.blocks = nn.ModuleList() stage_block_id = 0 pool_countdown = blocks[0] @@ -313,7 +316,7 @@ def __init__( total_stage_blocks = sum(blocks) for i, num_subblocks in enumerate(blocks): for j in range(num_subblocks): - next_block_index = i + 1 if j + 1 == num_subblocks and i + 1 < len(embed_channels) else i + next_block_index = i + 1 if j + 1 == num_subblocks and i + 1 < num_blocks else i output_channels = embed_channels[next_block_index] stride_q = (1, 1, 1) @@ -367,7 +370,7 @@ def __init__( nn.init.trunc_normal_(weights, std=0.02) def forward(self, x: torch.Tensor) -> torch.Tensor: - # patchify and reshape: (B, C, T, H, W) -> (B, patch_embed_dim, T', H', W') -> (B, THW', patch_embed_dim) + # patchify and reshape: (B, C, T, H, W) -> (B, embed_channels[0], T', H', W') -> (B, THW', embed_channels[0]) x = self.conv_proj(x) x = x.flatten(2).transpose(1, 2) From f62bd8389b5637d98b4bb0178642275e206bcf94 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 30 May 2022 09:42:50 +0100 Subject: [PATCH 04/11] Adding docs and minor refactoring --- torchvision/models/video/__init__.py | 2 +- .../models/video/{mvit.py => mvitv2.py} | 113 +++++++++++++++--- 2 files changed, 98 insertions(+), 17 deletions(-) rename torchvision/models/video/{mvit.py => mvitv2.py} (75%) diff --git a/torchvision/models/video/__init__.py b/torchvision/models/video/__init__.py index 8990f64a1dc..393a26ccbbe 100644 --- a/torchvision/models/video/__init__.py +++ b/torchvision/models/video/__init__.py @@ -1,2 +1,2 @@ -from .mvit import * +from .mvitv2 import * from .resnet import * diff --git a/torchvision/models/video/mvit.py b/torchvision/models/video/mvitv2.py similarity index 75% rename from torchvision/models/video/mvit.py rename to torchvision/models/video/mvitv2.py index df73e0d1a10..4bbc828d5aa 100644 --- a/torchvision/models/video/mvit.py +++ b/torchvision/models/video/mvitv2.py @@ -7,14 +7,14 @@ import torch.nn as nn from ...ops import StochasticDepth, MLP +from ...utils import _log_api_usage_once from .._api import WeightsEnum from .._utils import _ovewrite_named_param -__all__ = ["mvitv2_t", "mvitv2_s", "mvitv2_b", "MViTV2_T_Weights", "MViTV2_S_Weights", "MViTV2_B_Weights"] +__all__ = ["MViTv2", "MViTV2_T_Weights", "MViTV2_S_Weights", "MViTV2_B_Weights", "mvitv2_t", "mvitv2_s", "mvitv2_b"] -# TODO: add docs # TODO: add weights # TODO: test on references @@ -262,7 +262,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat((class_token, x), dim=1) + pos_embedding -class MultiscaleVisionTransformer(nn.Module): +class MViTv2(nn.Module): def __init__( self, spatial_size: Tuple[int, int], @@ -277,10 +277,34 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, num_classes: int = 400, - block: Callable[..., nn.Module] = None, + block: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: + """ + MViTv2 main class. + + Args: + spatial_size (tuple of ints): The spacial size of the input as ``(H, W)``. 
+ temporal_size (int): The temporal size ``T`` of the input. + embed_channels (list of ints): A list with the embedding dimensions of each block group. + blocks (list of ints): A list with the number of blocks of each block group. + heads (list of ints): A list with the number of heads of each block group. + pool_kv_stride (tuple of ints): The initialize pooling stride of the first block. + pool_q_stride (tuple of ints): The pooling stride which reduces q in each block group. + pool_kvq_kernel (tuple of ints): The pooling kernel for the attention. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + num_classes (int): The number of classes. + block (callable, optional): Module specifying the layer which consists of the attention and mlp. + norm_layer (callable, optional): Module specifying the normalization layer to use. + """ super().__init__() + # This implementation employs a different parameterization scheme than the one used at PyTorch Video: + # https://github.com/facebookresearch/pytorchvideo/blob/718d0a4/pytorchvideo/models/vision_transformers.py + # We remove any experimental configuration that didn't make it to the final variants of the models. To represent + # the configuration of the architecture we use the simplified form suggested at Table 1 of the paper. + _log_api_usage_once(self) num_blocks = len(blocks) if num_blocks != len(embed_channels) or num_blocks != len(heads): raise ValueError("The parameters 'embed_channels', 'blocks' and 'heads' must have equal length.") @@ -307,7 +331,7 @@ def __init__( temporal_size=temporal_size // self.conv_proj.stride[0], ) - # Encoder modules + # Encoder module self.blocks = nn.ModuleList() stage_block_id = 0 pool_countdown = blocks[0] @@ -399,7 +423,7 @@ def _mvitv2( weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, -) -> MultiscaleVisionTransformer: +) -> MViTv2: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) assert weights.meta["min_size"][0] == weights.meta["min_size"][1] @@ -407,7 +431,7 @@ def _mvitv2( # TODO: add min_temporal_size in the meta-data? spatial_size = kwargs.pop("spatial_size", (224, 224)) - model = MultiscaleVisionTransformer( + model = MViTv2( spatial_size=spatial_size, temporal_size=temporal_size, embed_channels=embed_channels, @@ -435,9 +459,28 @@ class MViTV2_B_Weights(WeightsEnum): pass -def mvitv2_t( - *, weights: Optional[MViTV2_T_Weights] = None, progress: bool = True, **kwargs: Any -) -> MultiscaleVisionTransformer: +def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2: + """ + Constructs a tiny MViTv2 architecture from + `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection + `__. + + Args: + weights (:class:`~torchvision.models.video.MViTV2_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViTV2_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. 
autoclass:: torchvision.models.video.MViTV2_T_Weights + :members: + """ weights = MViTV2_T_Weights.verify(weights) return _mvitv2( @@ -453,9 +496,28 @@ def mvitv2_t( ) -def mvitv2_s( - *, weights: Optional[MViTV2_S_Weights] = None, progress: bool = True, **kwargs: Any -) -> MultiscaleVisionTransformer: +def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2: + """ + Constructs a tiny MViTv2 architecture from + `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection + `__. + + Args: + weights (:class:`~torchvision.models.video.MViTV2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViTV2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.MViTV2_S_Weights + :members: + """ weights = MViTV2_S_Weights.verify(weights) return _mvitv2( @@ -471,9 +533,28 @@ def mvitv2_s( ) -def mvitv2_b( - *, weights: Optional[MViTV2_B_Weights] = None, progress: bool = True, **kwargs: Any -) -> MultiscaleVisionTransformer: +def mvitv2_b(*, weights: Optional[MViTV2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2: + """ + Constructs a tiny MViTv2 architecture from + `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection + `__. + + Args: + weights (:class:`~torchvision.models.video.MViTV2_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.video.MViTV2_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.video.MViTV2`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.video.MViTV2_B_Weights + :members: + """ weights = MViTV2_B_Weights.verify(weights) return _mvitv2( From 881565c69b5c781ec69ef39251037ed9490b4178 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 30 May 2022 09:46:56 +0100 Subject: [PATCH 05/11] Add `min_temporal_size` in the supported meta-data. 
--- test/test_extended_models.py | 1 + torchvision/models/video/mvitv2.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 396e79c3f6d..05efd39099e 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -87,6 +87,7 @@ def test_schema_meta_validation(model_fn): "license", "_metrics", "min_size", + "min_temporal_size", "num_params", "recipe", "unquantized", diff --git a/torchvision/models/video/mvitv2.py b/torchvision/models/video/mvitv2.py index 4bbc828d5aa..36183e2c0b0 100644 --- a/torchvision/models/video/mvitv2.py +++ b/torchvision/models/video/mvitv2.py @@ -415,7 +415,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _mvitv2( - temporal_size: int, embed_channels: List[int], blocks: List[int], heads: List[int], @@ -428,8 +427,9 @@ def _mvitv2( _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) assert weights.meta["min_size"][0] == weights.meta["min_size"][1] _ovewrite_named_param(kwargs, "spatial_size", weights.meta["min_size"][0]) - # TODO: add min_temporal_size in the meta-data? + _ovewrite_named_param(kwargs, "temporal_size", weights.meta["min_temporal_size"]) spatial_size = kwargs.pop("spatial_size", (224, 224)) + temporal_size = kwargs.pop("temporal_size", 16) model = MViTv2( spatial_size=spatial_size, From 06a6694e4861e166947af4115bfd385d5975a28a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 30 May 2022 10:10:40 +0100 Subject: [PATCH 06/11] Switch Tuple[int, int, int] with List[int] to support easier the 2D case --- torchvision/models/video/mvitv2.py | 73 ++++++++++++++++-------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/torchvision/models/video/mvitv2.py b/torchvision/models/video/mvitv2.py index 36183e2c0b0..f9dbac1244e 100644 --- a/torchvision/models/video/mvitv2.py +++ b/torchvision/models/video/mvitv2.py @@ -1,6 +1,6 @@ import math from functools import partial -from typing import Any, Callable, List, Optional, Sequence, Tuple, cast +from typing import Any, Callable, List, Optional, Sequence, Tuple import torch import torch.fx @@ -94,11 +94,11 @@ def __init__( self, embed_dim: int, num_heads: int, + kernel_q: List[int], + kernel_kv: List[int], + stride_q: List[int], + stride_kv: List[int], dropout: float = 0.0, - kernel_q: Tuple[int, int, int] = (1, 1, 1), - kernel_kv: Tuple[int, int, int] = (1, 1, 1), - stride_q: Tuple[int, int, int] = (1, 1, 1), - stride_kv: Tuple[int, int, int] = (1, 1, 1), norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ) -> None: super().__init__() @@ -115,14 +115,14 @@ def __init__( self.pool_q: Optional[nn.Module] = None if _prod(kernel_q) > 1 or _prod(stride_q) > 1: - padding_q = cast(Tuple[int, int, int], tuple(int(q // 2) for q in kernel_q)) + padding_q = [int(q // 2) for q in kernel_q] self.pool_q = Pool( nn.Conv3d( self.head_dim, self.head_dim, - kernel_q, - stride=stride_q, - padding=padding_q, + kernel_q, # type: ignore[arg-type] + stride=stride_q, # type: ignore[arg-type] + padding=padding_q, # type: ignore[arg-type] groups=self.head_dim, bias=False, ), @@ -132,14 +132,14 @@ def __init__( self.pool_k: Optional[nn.Module] = None self.pool_v: Optional[nn.Module] = None if _prod(kernel_kv) > 1 or _prod(stride_kv) > 1: - padding_kv = cast(Tuple[int, int, int], tuple(int(kv // 2) for kv in kernel_kv)) + padding_kv = [int(kv // 2) for kv in kernel_kv] self.pool_k = Pool( nn.Conv3d( self.head_dim, self.head_dim, - kernel_kv, - stride=stride_kv, - 
padding=padding_kv, + kernel_kv, # type: ignore[arg-type] + stride=stride_kv, # type: ignore[arg-type] + padding=padding_kv, # type: ignore[arg-type] groups=self.head_dim, bias=False, ), @@ -149,9 +149,9 @@ def __init__( nn.Conv3d( self.head_dim, self.head_dim, - kernel_kv, - stride=stride_kv, - padding=padding_kv, + kernel_kv, # type: ignore[arg-type] + stride=stride_kv, # type: ignore[arg-type] + padding=padding_kv, # type: ignore[arg-type] groups=self.head_dim, bias=False, ), @@ -185,21 +185,23 @@ def __init__( input_channels: int, output_channels: int, num_heads: int, + kernel_q: List[int], + kernel_kv: List[int], + stride_q: List[int], + stride_kv: List[int], dropout: float = 0.0, stochastic_depth_prob: float = 0.0, - kernel_q: Tuple[int, int, int] = (1, 1, 1), - kernel_kv: Tuple[int, int, int] = (1, 1, 1), - stride_q: Tuple[int, int, int] = (1, 1, 1), - stride_kv: Tuple[int, int, int] = (1, 1, 1), norm_layer: Callable[..., nn.Module] = nn.LayerNorm, ) -> None: super().__init__() self.pool_skip: Optional[nn.Module] = None if _prod(stride_q) > 1: - kernel_skip = cast(Tuple[int, int, int], tuple(s + 1 if s > 1 else s for s in stride_q)) - padding_skip = cast(Tuple[int, int, int], tuple(int(k // 2) for k in kernel_skip)) - self.pool_skip = Pool(nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None) + kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + padding_skip = [int(k // 2) for k in kernel_skip] + self.pool_skip = Pool( + nn.MaxPool3d(kernel_skip, stride=stride_q, padding=padding_skip), None # type: ignore[arg-type] + ) self.norm1 = norm_layer(input_channels) self.norm2 = norm_layer(input_channels) @@ -208,11 +210,11 @@ def __init__( self.attn = MultiscaleAttention( input_channels, num_heads, - dropout=dropout, kernel_q=kernel_q, kernel_kv=kernel_kv, stride_q=stride_q, stride_kv=stride_kv, + dropout=dropout, norm_layer=norm_layer, ) self.mlp = MLP( @@ -270,9 +272,9 @@ def __init__( embed_channels: List[int], blocks: List[int], heads: List[int], - pool_kv_stride: Tuple[int, int, int] = (1, 8, 8), - pool_q_stride: Tuple[int, int, int] = (1, 2, 2), - pool_kvq_kernel: Tuple[int, int, int] = (3, 3, 3), + pool_kv_stride: List[int], + pool_q_stride: List[int], + pool_kvq_kernel: List[int], dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, @@ -289,9 +291,9 @@ def __init__( embed_channels (list of ints): A list with the embedding dimensions of each block group. blocks (list of ints): A list with the number of blocks of each block group. heads (list of ints): A list with the number of heads of each block group. - pool_kv_stride (tuple of ints): The initialize pooling stride of the first block. - pool_q_stride (tuple of ints): The pooling stride which reduces q in each block group. - pool_kvq_kernel (tuple of ints): The pooling kernel for the attention. + pool_kv_stride (list of ints): The initiale pooling stride of the first block. + pool_q_stride (list of ints): The pooling stride which reduces q in each block group. + pool_kvq_kernel (list of ints): The pooling kernel for the attention. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. 
@@ -343,12 +345,12 @@ def __init__( next_block_index = i + 1 if j + 1 == num_subblocks and i + 1 < num_blocks else i output_channels = embed_channels[next_block_index] - stride_q = (1, 1, 1) + stride_q = [1, 1, 1] if pool_countdown == 0: stride_q = pool_q_stride pool_countdown = blocks[next_block_index] - stride_kv = cast(Tuple[int, int, int], tuple(max(s // stride_q[d], 1) for d, s in enumerate(stride_kv))) + stride_kv = [max(s // stride_q[d], 1) for d, s in enumerate(stride_kv)] # adjust stochastic depth probability based on the depth of the stage block sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) @@ -358,12 +360,12 @@ def __init__( input_channels=input_channels, output_channels=output_channels, num_heads=heads[i], - dropout=attention_dropout, - stochastic_depth_prob=sd_prob, kernel_q=pool_kvq_kernel, kernel_kv=pool_kvq_kernel, stride_q=stride_q, stride_kv=stride_kv, + dropout=attention_dropout, + stochastic_depth_prob=sd_prob, norm_layer=norm_layer, ) ) @@ -437,6 +439,9 @@ def _mvitv2( embed_channels=embed_channels, blocks=blocks, heads=heads, + pool_kv_stride=kwargs.pop("pool_kv_stride", [1, 8, 8]), + pool_q_stride=kwargs.pop("pool_q_stride", [1, 2, 2]), + pool_kvq_kernel=kwargs.pop("pool_kvq_kernel", [3, 3, 3]), stochastic_depth_prob=stochastic_depth_prob, **kwargs, ) From 1330d9ce98b3cef3da942c3f28d4c65338e15c7e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 30 May 2022 13:06:04 +0100 Subject: [PATCH 07/11] Adding more docs and references --- docs/source/models.rst | 1 + docs/source/models/video_mvitv2.rst | 28 ++++++++++++++++++++++++++++ torchvision/models/video/mvitv2.py | 13 ++++++++----- 3 files changed, 37 insertions(+), 5 deletions(-) create mode 100644 docs/source/models/video_mvitv2.rst diff --git a/docs/source/models.rst b/docs/source/models.rst index b549c25bf94..b6f11ae9db0 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -459,6 +459,7 @@ pre-trained weights: .. toctree:: :maxdepth: 1 + models/video_mvitv2 models/video_resnet | diff --git a/docs/source/models/video_mvitv2.rst b/docs/source/models/video_mvitv2.rst new file mode 100644 index 00000000000..8b7b8ca685f --- /dev/null +++ b/docs/source/models/video_mvitv2.rst @@ -0,0 +1,28 @@ +Video ResNet +============ + +.. currentmodule:: torchvision.models.video + +The MViTv2 model is based on the +`MViTv2: Improved Multiscale Vision Transformers for Classification and Detection +`__ and `Multiscale Vision Transformers +`__ papers. + + +Model builders +-------------- + +The following model builders can be used to instantiate a MViTV2 model, with or +without pre-trained weights. All the model builders internally rely on the +``torchvision.models.video.MViTV2`` base class. Please refer to the `source +code +`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + mvitv2_t + mvitv2_s + mvitv2_b diff --git a/torchvision/models/video/mvitv2.py b/torchvision/models/video/mvitv2.py index f9dbac1244e..81a07b746f3 100644 --- a/torchvision/models/video/mvitv2.py +++ b/torchvision/models/video/mvitv2.py @@ -468,7 +468,8 @@ def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = Tru """ Constructs a tiny MViTv2 architecture from `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection - `__. + `__ and `Multiscale Vision Transformers + `__. 
Args: weights (:class:`~torchvision.models.video.MViTV2_T_Weights`, optional): The @@ -503,9 +504,10 @@ def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = Tru def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2: """ - Constructs a tiny MViTv2 architecture from + Constructs a small MViTv2 architecture from `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection - `__. + `__ and `Multiscale Vision Transformers + `__. Args: weights (:class:`~torchvision.models.video.MViTV2_S_Weights`, optional): The @@ -540,9 +542,10 @@ def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = Tru def mvitv2_b(*, weights: Optional[MViTV2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2: """ - Constructs a tiny MViTv2 architecture from + Constructs a base MViTv2 architecture from `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection - `__. + `__ and `Multiscale Vision Transformers + `__. Args: weights (:class:`~torchvision.models.video.MViTV2_B_Weights`, optional): The From 683d520905ea3d0399be780973696ebd80e8b27c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 30 May 2022 13:18:58 +0100 Subject: [PATCH 08/11] Change naming conventions of classes to follow the same pattern as MobileNetV3 --- docs/source/models/video_mvitv2.rst | 8 +-- ... => ModelTester.test_mvit_v2_b_expect.pkl} | Bin ... => ModelTester.test_mvit_v2_s_expect.pkl} | Bin ... => ModelTester.test_mvit_v2_t_expect.pkl} | Bin test/test_models.py | 6 +- torchvision/models/video/mvitv2.py | 60 ++++++++++-------- 6 files changed, 41 insertions(+), 33 deletions(-) rename test/expect/{ModelTester.test_mvitv2_b_expect.pkl => ModelTester.test_mvit_v2_b_expect.pkl} (100%) rename test/expect/{ModelTester.test_mvitv2_s_expect.pkl => ModelTester.test_mvit_v2_s_expect.pkl} (100%) rename test/expect/{ModelTester.test_mvitv2_t_expect.pkl => ModelTester.test_mvit_v2_t_expect.pkl} (100%) diff --git a/docs/source/models/video_mvitv2.rst b/docs/source/models/video_mvitv2.rst index 8b7b8ca685f..e9ad556ded7 100644 --- a/docs/source/models/video_mvitv2.rst +++ b/docs/source/models/video_mvitv2.rst @@ -3,7 +3,7 @@ Video ResNet .. currentmodule:: torchvision.models.video -The MViTv2 model is based on the +The MViT V2 model is based on the `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection `__ and `Multiscale Vision Transformers `__ papers. @@ -23,6 +23,6 @@ more details about this class. 
:toctree: generated/ :template: function.rst - mvitv2_t - mvitv2_s - mvitv2_b + mvit_v2_t + mvit_v2_s + mvit_v2_b diff --git a/test/expect/ModelTester.test_mvitv2_b_expect.pkl b/test/expect/ModelTester.test_mvit_v2_b_expect.pkl similarity index 100% rename from test/expect/ModelTester.test_mvitv2_b_expect.pkl rename to test/expect/ModelTester.test_mvit_v2_b_expect.pkl diff --git a/test/expect/ModelTester.test_mvitv2_s_expect.pkl b/test/expect/ModelTester.test_mvit_v2_s_expect.pkl similarity index 100% rename from test/expect/ModelTester.test_mvitv2_s_expect.pkl rename to test/expect/ModelTester.test_mvit_v2_s_expect.pkl diff --git a/test/expect/ModelTester.test_mvitv2_t_expect.pkl b/test/expect/ModelTester.test_mvit_v2_t_expect.pkl similarity index 100% rename from test/expect/ModelTester.test_mvitv2_t_expect.pkl rename to test/expect/ModelTester.test_mvit_v2_t_expect.pkl diff --git a/test/test_models.py b/test/test_models.py index 6e586049008..f3c9f7ed2ef 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -309,13 +309,13 @@ def _check_input_backprop(model, inputs): "image_size": 56, "input_shape": (1, 3, 56, 56), }, - "mvitv2_t": { + "mvit_v2_t": { "input_shape": (1, 3, 16, 224, 224), }, - "mvitv2_s": { + "mvit_v2_s": { "input_shape": (1, 3, 16, 224, 224), }, - "mvitv2_b": { + "mvit_v2_b": { "input_shape": (1, 3, 32, 224, 224), }, } diff --git a/torchvision/models/video/mvitv2.py b/torchvision/models/video/mvitv2.py index 81a07b746f3..f804a6851df 100644 --- a/torchvision/models/video/mvitv2.py +++ b/torchvision/models/video/mvitv2.py @@ -12,7 +12,15 @@ from .._utils import _ovewrite_named_param -__all__ = ["MViTv2", "MViTV2_T_Weights", "MViTV2_S_Weights", "MViTV2_B_Weights", "mvitv2_t", "mvitv2_s", "mvitv2_b"] +__all__ = [ + "MViTV2", + "MViT_V2_T_Weights", + "MViT_V2_S_Weights", + "MViT_V2_B_Weights", + "mvit_v2_t", + "mvit_v2_s", + "mvit_v2_b", +] # TODO: add weights @@ -264,7 +272,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat((class_token, x), dim=1) + pos_embedding -class MViTv2(nn.Module): +class MViTV2(nn.Module): def __init__( self, spatial_size: Tuple[int, int], @@ -283,7 +291,7 @@ def __init__( norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: """ - MViTv2 main class. + MViT V2 main class. Args: spatial_size (tuple of ints): The spacial size of the input as ``(H, W)``. 
@@ -424,7 +432,7 @@ def _mvitv2( weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, -) -> MViTv2: +) -> MViTV2: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) assert weights.meta["min_size"][0] == weights.meta["min_size"][1] @@ -433,7 +441,7 @@ def _mvitv2( spatial_size = kwargs.pop("spatial_size", (224, 224)) temporal_size = kwargs.pop("temporal_size", 16) - model = MViTv2( + model = MViTV2( spatial_size=spatial_size, temporal_size=temporal_size, embed_channels=embed_channels, @@ -452,29 +460,29 @@ def _mvitv2( return model -class MViTV2_T_Weights(WeightsEnum): +class MViT_V2_T_Weights(WeightsEnum): pass -class MViTV2_S_Weights(WeightsEnum): +class MViT_V2_S_Weights(WeightsEnum): pass -class MViTV2_B_Weights(WeightsEnum): +class MViT_V2_B_Weights(WeightsEnum): pass -def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2: +def mvit_v2_t(*, weights: Optional[MViT_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: """ - Constructs a tiny MViTv2 architecture from + Constructs a tiny MViTV2 architecture from `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection `__ and `Multiscale Vision Transformers `__. Args: - weights (:class:`~torchvision.models.video.MViTV2_T_Weights`, optional): The + weights (:class:`~torchvision.models.video.MViT_V2_T_Weights`, optional): The pretrained weights to use. See - :class:`~torchvision.models.video.MViTV2_T_Weights` below for + :class:`~torchvision.models.video.MViT_V2_T_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the @@ -484,10 +492,10 @@ def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = Tru `_ for more details about this class. - .. autoclass:: torchvision.models.video.MViTV2_T_Weights + .. autoclass:: torchvision.models.video.MViT_V2_T_Weights :members: """ - weights = MViTV2_T_Weights.verify(weights) + weights = MViT_V2_T_Weights.verify(weights) return _mvitv2( spatial_size=(224, 224), @@ -502,17 +510,17 @@ def mvitv2_t(*, weights: Optional[MViTV2_T_Weights] = None, progress: bool = Tru ) -def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2: +def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: """ - Constructs a small MViTv2 architecture from + Constructs a small MViTV2 architecture from `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection `__ and `Multiscale Vision Transformers `__. Args: - weights (:class:`~torchvision.models.video.MViTV2_S_Weights`, optional): The + weights (:class:`~torchvision.models.video.MViT_V2_S_Weights`, optional): The pretrained weights to use. See - :class:`~torchvision.models.video.MViTV2_S_Weights` below for + :class:`~torchvision.models.video.MViT_V2_S_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the @@ -522,10 +530,10 @@ def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = Tru `_ for more details about this class. - .. autoclass:: torchvision.models.video.MViTV2_S_Weights + .. 
autoclass:: torchvision.models.video.MViT_V2_S_Weights :members: """ - weights = MViTV2_S_Weights.verify(weights) + weights = MViT_V2_S_Weights.verify(weights) return _mvitv2( spatial_size=(224, 224), @@ -540,17 +548,17 @@ def mvitv2_s(*, weights: Optional[MViTV2_S_Weights] = None, progress: bool = Tru ) -def mvitv2_b(*, weights: Optional[MViTV2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTv2: +def mvit_v2_b(*, weights: Optional[MViT_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViTV2: """ - Constructs a base MViTv2 architecture from + Constructs a base MViTV2 architecture from `MViTv2: Improved Multiscale Vision Transformers for Classification and Detection `__ and `Multiscale Vision Transformers `__. Args: - weights (:class:`~torchvision.models.video.MViTV2_B_Weights`, optional): The + weights (:class:`~torchvision.models.video.MViT_V2_B_Weights`, optional): The pretrained weights to use. See - :class:`~torchvision.models.video.MViTV2_B_Weights` below for + :class:`~torchvision.models.video.MViT_V2_B_Weights` below for more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the @@ -560,10 +568,10 @@ def mvitv2_b(*, weights: Optional[MViTV2_B_Weights] = None, progress: bool = Tru `_ for more details about this class. - .. autoclass:: torchvision.models.video.MViTV2_B_Weights + .. autoclass:: torchvision.models.video.MViT_V2_B_Weights :members: """ - weights = MViTV2_B_Weights.verify(weights) + weights = MViT_V2_B_Weights.verify(weights) return _mvitv2( spatial_size=(224, 224), From 677f5a8b2c6ca3bd901b403a5afaf9713f305ac6 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 30 May 2022 19:21:27 +0100 Subject: [PATCH 09/11] Fix test breakage. --- test/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_models.py b/test/test_models.py index f3c9f7ed2ef..c8d81216aa0 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -347,7 +347,7 @@ def _check_input_backprop(model, inputs): skipped_big_models = { "vit_h_14", "regnet_y_128gf", - "mvitv2_b", + "mvit_v2_b", } # The following contains configuration and expected values to be used tests that are model specific From 833fb90522900945e96cceeeae9746d88889f3aa Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 31 May 2022 19:53:31 +0100 Subject: [PATCH 10/11] Update todos --- torchvision/models/video/mvitv2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/models/video/mvitv2.py b/torchvision/models/video/mvitv2.py index f804a6851df..dd2d77fb928 100644 --- a/torchvision/models/video/mvitv2.py +++ b/torchvision/models/video/mvitv2.py @@ -23,6 +23,8 @@ ] +# TODO: check if we should implement relative pos embedding (Section 4.1 in the paper). Ref: +# https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py#L45 # TODO: add weights # TODO: test on references From 5ec3eff4ad480f36c735cf58fc1569b8a48b5f48 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 1 Jun 2022 16:32:23 +0100 Subject: [PATCH 11/11] Performance optimizations. 
--- torchvision/models/video/mvitv2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/video/mvitv2.py b/torchvision/models/video/mvitv2.py index dd2d77fb928..206c745f466 100644 --- a/torchvision/models/video/mvitv2.py +++ b/torchvision/models/video/mvitv2.py @@ -80,7 +80,7 @@ def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]) -> Tuple[torch.Ten class_token, x = torch.tensor_split(x, indices=(1,), dim=2) x = x.transpose(2, 3) B, N, C = x.shape[:3] - x = x.reshape((B * N, C) + thw) + x = x.reshape((B * N, C) + thw).contiguous() # normalizing prior pooling is useful when we use BN which can be absorbed to speed up inference if self.norm_before_pool and self.norm_act is not None: @@ -271,7 +271,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: pos_embedding.add_(self.spatial_pos.unsqueeze(0).expand(self.temporal_size, -1, -1).reshape(-1, embed_size)) pos_embedding = torch.cat((self.class_pos.unsqueeze(0), pos_embedding), dim=0).unsqueeze(0) class_token = self.class_token.expand(x.size(0), -1).unsqueeze(1) - return torch.cat((class_token, x), dim=1) + pos_embedding + return torch.cat((class_token, x), dim=1).add_(pos_embedding) class MViTV2(nn.Module):
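
A minimal usage sketch for the builders introduced in this series, using the renamed entry points from PATCH 08/11 (mvit_v2_t, mvit_v2_s, mvit_v2_b). No pre-trained weights are registered yet (the weight enums are still empty), so `weights` stays None; input shapes follow the test configuration in test_models.py (16-frame 224x224 clips for the tiny/small variants, 32-frame clips for the base variant).

    import torch
    from torchvision.models.video import mvit_v2_s

    # Build the small variant; only random initialization is available in this
    # patch series because MViT_V2_S_Weights has no entries yet.
    model = mvit_v2_s(weights=None)
    model.eval()

    # (B, C, T, H, W): tiny/small expect 16-frame 224x224 clips;
    # mvit_v2_b expects 32-frame clips (see _model_params in test_models.py).
    clip = torch.rand(1, 3, 16, 224, 224)

    with torch.no_grad():
        scores = model(clip)

    print(scores.shape)  # torch.Size([1, 400]) -- Kinetics-400 head by default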