Adding ViT to torchvision/models #4594
Merged
Changes from 1 commit (fbd0024)

Commits (39):
fbd0024  [vit] Adding ViT to torchvision/models (yiwen-song)
7521ffe  adding pre-logits layer + resolving comments (yiwen-song)
7e63685  Merge branch 'pytorch:main' into main (yiwen-song)
2dd878a  Merge branch 'pytorch:main' into main (yiwen-song)
53b6967  Fix the model attribute bug (yiwen-song)
fe248f0  Merge branch 'main' of https://github.com/sallysyw/vision into main (yiwen-song)
a84361a  Change version to arch (yiwen-song)
f981519  Merge branch 'pytorch:main' into main (yiwen-song)
9d2ef95  Merge branch 'main' into main (datumbox)
0aaac5b  Merge branch 'pytorch:main' into main (yiwen-song)
1cf8b92  Merge branch 'pytorch:main' into main (yiwen-song)
c2f3826  fix failing unittests (yiwen-song)
35c1d22  remove useless prints (yiwen-song)
1aff5cd  Merge branch 'pytorch:main' into main (yiwen-song)
568c560  reduce input size to fix unittests (yiwen-song)
8e71e4b  Increase windows-cpu executor to 2xlarge (yiwen-song)
f9860ec  Use `batch_first=True` and remove classifier (yiwen-song)
4d7d7fe  Merge branch 'pytorch:main' into main (yiwen-song)
b795e85  Change resource_class back to xlarge (yiwen-song)
ff64591  Remove vit_h_14 (yiwen-song)
bd3a747  Remove vit_h_14 from __all__ (yiwen-song)
8f88592  Move vision_transformer.py into prototype (yiwen-song)
22025ac  Fix formatting issue (yiwen-song)
26bc529  remove arch in builder (yiwen-song)
cc22238  Fix type err in model builder (yiwen-song)
1d4e2aa  Merge branch 'main' into main (yiwen-song)
091bf6b  Merge branch 'pytorch:main' into main (yiwen-song)
41edd15  address comments and trigger unittests (yiwen-song)
48ce69e  remove the prototype import in torchvision.models (yiwen-song)
0caf745  Merge branch 'main' into main (yiwen-song)
3a6b445  Adding vit back to models to trigger CircleCI test (yiwen-song)
72c5af7  fix test_jit_forward_backward (yiwen-song)
aae308c  Move all to prototype. (datumbox)
7b1e59e  Merge branch 'main' into main (datumbox)
717b6af  Merge branch 'main' into main
f0df7f8  Adopt new helper methods and fix prototype tests. (datumbox)
3807b23  Remove unused import. (datumbox)
eabec95  Merge branch 'main' into main (yiwen-song)
40b566b  Merge branch 'main' into main (yiwen-song)
vision_transformer.py (new file):

@@ -0,0 +1,307 @@
# Implement ViT from:
# https://arxiv.org/abs/2010.11929

# References:
# https://github.com/google-research/vision_transformer
# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py

import math
from collections import OrderedDict
from functools import partial
from typing import Any

import torch
import torch.nn as nn
from torch import Tensor

__all__ = [
    "VisionTransformer",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
]


LayerNorm = partial(nn.LayerNorm, eps=1e-6)

class MLPBlock(nn.Sequential):
    """Transformer MLP block."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout_rate: float):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, mlp_dim)
        self.act = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout_rate)
        self.linear_2 = nn.Linear(mlp_dim, in_dim)
        self.dropout_2 = nn.Dropout(dropout_rate)
        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.normal_(self.linear_1.bias, std=1e-6)
        nn.init.normal_(self.linear_2.bias, std=1e-6)

class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self, num_heads: int, hidden_dim: int, mlp_dim: int, dropout_rate: float, attention_dropout_rate: float
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = LayerNorm(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout_rate)
        self.dropout = nn.Dropout(dropout_rate)

        # MLP block
        self.ln_2 = LayerNorm(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout_rate)

    def forward(self, input: Tensor):
        # assert input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}"
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout_rate: float,
        attention_dropout_rate: float,
    ):
        super().__init__()
        # Note that batch_size is on the second dim because
        # we have batch_first=False in nn.MultiheadAttention() by default
        self.pos_embedding = nn.Parameter(torch.empty(seq_length, 1, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout_rate)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout_rate,
                attention_dropout_rate,
            )
        self.layers = nn.Sequential(layers)
        self.ln = LayerNorm(hidden_dim)

    def forward(self, input: Tensor):
        # assert input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}"
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))

class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout_rate: float = 0.0,
        attention_dropout_rate: float = 0.0,
        classifier: str = "token",
        num_classes: int = 1000,
    ):
        super().__init__()
        # assert image_size % patch_size == 0, "Input shape indivisible by patch size!"
        # assert classifier in ["token", "gap"], "Unexpected classifier mode!"
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout_rate = attention_dropout_rate
        self.dropout_rate = dropout_rate
        self.classifier = classifier
        self.num_classes = num_classes

        input_channels = 3

        # The conv_proj is a more efficient version of reshaping, permuting
        # and projecting the input
        self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)

        seq_length = (image_size // patch_size) ** 2
        if self.classifier == "token":
            # Add a class token
            self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
            seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout_rate,
            attention_dropout_rate,
        )
        self.seq_length = seq_length

        self.head = nn.Linear(hidden_dim, num_classes)
        self.init_weights()

    def init_weights(self):
        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
        nn.init.zeros_(self.conv_proj.bias)
        nn.init.zeros_(self.head.weight)

    def forward(self, x: torch.Tensor):
        n, c, h, w = x.shape
        p = self.patch_size
        # assert h == w == self.image_size
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> ((n_h * n_w), n, hidden_dim)
        # The self attention layer expects inputs in the format (S, N, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(2, 0, 1)

        if self.classifier == "token":
            # Expand the class token to the full batch.
            batch_class_token = self.class_token.expand(-1, n, -1)
            x = torch.cat([batch_class_token, x], dim=0)

        x = self.encoder(x)

        if self.classifier == "token":
            # Classifier as used by standard language architectures
            x = x[0, :, :]
        elif self.classifier == "gap":
            # Classifier as used by standard vision architectures
            x = x.mean(dim=0)
        else:
            raise ValueError(f"Invalid classifier={self.classifier}")

        x = self.head(x)

        return x

def _vision_transformer(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> VisionTransformer:
    if kwargs.get("image_size", None) is None:
        model = VisionTransformer(image_size=224, **kwargs)
    else:
        model = VisionTransformer(**kwargs)
    # TODO: Add pre-trained models
    return model

def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a ViT_b_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        version="b_16",
        pretrained=pretrained,
        progress=progress,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        **kwargs,
    )


def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a ViT_b_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        version="b_32",
        pretrained=pretrained,
        progress=progress,
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        **kwargs,
    )


def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a ViT_l_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        version="l_16",
        pretrained=pretrained,
        progress=progress,
        patch_size=16,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        **kwargs,
    )


def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a ViT_l_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        version="l_32",
        pretrained=pretrained,
        progress=progress,
        patch_size=32,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        **kwargs,
    )
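
For readers who want to try the file from this commit as-is, here is a quick sanity check. This is only a sketch: it assumes the module is saved locally as vision_transformer.py (a later commit in this PR moves it into the prototype area, so the final import path differs), and pretrained weights are not yet available at this point.

import torch
from vision_transformer import vit_b_16  # hypothetical local import; the path changed later in this PR

model = vit_b_16()  # defaults to image_size=224, so seq_length = (224 // 16) ** 2 + 1 = 197
model.eval()
x = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([2, 1000]), one logit per ImageNet class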
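The in-code comment notes that conv_proj is "a more efficient version of reshaping, permuting and projecting the input". A small self-contained sketch can confirm the claimed equivalence: a Conv2d whose kernel size equals its stride computes exactly the same patch embeddings as explicitly cutting the image into patches, flattening each one, and applying a shared linear projection with the same weights.

import torch
import torch.nn as nn

n, c, hidden_dim, p = 2, 3, 768, 16
conv_proj = nn.Conv2d(c, hidden_dim, kernel_size=p, stride=p)
x = torch.randn(n, c, 224, 224)

# Path 1: the strided-conv patchify used by the model.
out_conv = conv_proj(x).flatten(2).transpose(1, 2)  # (n, n_h * n_w, hidden_dim)

# Path 2: explicit patch extraction plus a linear projection reusing the conv weights.
patches = x.unfold(2, p, p).unfold(3, p, p)           # (n, c, n_h, n_w, p, p)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(n, -1, c * p * p)
out_manual = patches @ conv_proj.weight.reshape(hidden_dim, -1).t() + conv_proj.bias

print(torch.allclose(out_conv, out_manual, atol=1e-5))  # True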
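The classifier argument also supports global average pooling ("gap") instead of the BERT-style class token; in that mode no class token is prepended, so seq_length stays at (image_size // patch_size) ** 2 = 196 for a 224-pixel image with 16-pixel patches. A tiny configuration (hypothetical hyperparameters chosen only to keep the example fast, not one of the published ViT variants) illustrates the difference:

import torch
from vision_transformer import VisionTransformer  # hypothetical local import, as above

model = VisionTransformer(
    image_size=224, patch_size=16, num_layers=2, num_heads=4,
    hidden_dim=64, mlp_dim=128, classifier="gap",
)
assert model.seq_length == (224 // 16) ** 2  # 196: no class token in "gap" mode
out = model(torch.randn(2, 3, 224, 224))
assert out.shape == (2, 1000)  # features are mean-pooled over all 196 patch tokens, then classified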