Commit b3e90a2

add simple vit, from https://arxiv.org/abs/2205.01580

1 parent 4ef72fc commit b3e90a2

File tree

- README.md
- setup.py
- vit_pytorch/__init__.py
- vit_pytorch/simple_vit.py

4 files changed: +157 −1 lines changed

Diff for: README.md

+38

@@ -6,6 +6,7 @@
 - [Install](#install)
 - [Usage](#usage)
 - [Parameters](#parameters)
+- [Simple ViT](#simple-vit)
 - [Distillation](#distillation)
 - [Deep ViT](#deep-vit)
 - [CaiT](#cait)
@@ -106,6 +107,33 @@ Embedding dropout rate.
 - `pool`: string, either `cls` token pooling or `mean` pooling
 
 
+## Simple ViT
+
+<a href="https://arxiv.org/abs/2205.01580">An update</a> from some of the same authors of the original paper proposes simplifications to `ViT` that allow it to train faster and better.
+
+These simplifications include 2D sinusoidal positional embeddings, global average pooling (no CLS token), no dropout, batch sizes of 1024 rather than 4096, and the use of RandAugment and MixUp augmentations. They also show that a simple linear layer at the end is not significantly worse than the original MLP head.
+
+You can use it by importing `SimpleViT`, as shown below.
+
+```python
+import torch
+from vit_pytorch import SimpleViT
+
+v = SimpleViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 16,
+    mlp_dim = 2048
+)
+
+img = torch.randn(1, 3, 256, 256)
+
+preds = v(img) # (1, 1000)
+```
+
 ## Distillation
 
 <img src="./images/distill.png" width="300px"></img>
@@ -1669,6 +1697,16 @@ Coming from computer vision and new to transformers? Here are some resources that
 }
 ```
 
+```bibtex
+@misc{Beyer2022BetterPlainViT,
+    title = {Better plain ViT baselines for ImageNet-1k},
+    author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
+    publisher = {arXiv},
+    year = {2022}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title = {Attention Is All You Need},
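
The Simple ViT section added above swaps the learned positional embedding for a fixed 2D sinusoidal one. As a quick illustration (not part of the commit), the sketch below calls `posemb_sincos_2d` from the new `vit_pytorch/simple_vit.py` module on a dummy patch grid whose sizes mirror the README example: a 256 × 256 image with 32 × 32 patches gives an 8 × 8 grid with dim 1024.

```python
# Sketch only, not part of the commit: inspect the fixed 2D sin-cos positional embedding.
# Assumes posemb_sincos_2d is importable from vit_pytorch.simple_vit, as added in this commit.
import torch
from vit_pytorch.simple_vit import posemb_sincos_2d

patches = torch.randn(1, 8, 8, 1024)  # dummy (batch, height, width, dim) patch grid

pe = posemb_sincos_2d(patches)
print(pe.shape)  # torch.Size([64, 1024]) -- one fixed (non-learned) vector per patch position
```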

Diff for: setup.py

+1 −1

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.34.1',
+  version = '0.35.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

Diff for: vit_pytorch/__init__.py

+2

@@ -1,3 +1,5 @@
 from vit_pytorch.vit import ViT
+from vit_pytorch.simple_vit import SimpleViT
+
 from vit_pytorch.mae import MAE
 from vit_pytorch.dino import Dino

Diff for: vit_pytorch/simple_vit.py

+116 (new file)

@@ -0,0 +1,116 @@
+import torch
+from torch import nn
+
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+# helpers
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
+    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
+
+    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
+    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
+    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
+    omega = 1. / (temperature ** omega)
+
+    y = y.flatten()[:, None] * omega[None, :]
+    x = x.flatten()[:, None] * omega[None, :]
+    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
+    return pe.type(dtype)
+
+# classes
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, dim),
+        )
+    def forward(self, x):
+        return self.net(x)
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads = 8, dim_head = 64):
+        super().__init__()
+        inner_dim = dim_head * heads
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        self.norm = nn.LayerNorm(dim)
+
+        self.attend = nn.Softmax(dim = -1)
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+    def forward(self, x):
+        x = self.norm(x)
+
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+        attn = self.attend(dots)
+
+        out = torch.matmul(attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                Attention(dim, heads = heads, dim_head = dim_head),
+                FeedForward(dim, mlp_dim)
+            ]))
+    def forward(self, x):
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+        return x
+
+class SimpleViT(nn.Module):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
+        super().__init__()
+        image_height, image_width = pair(image_size)
+        patch_height, patch_width = pair(patch_size)
+
+        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+
+        num_patches = (image_height // patch_height) * (image_width // patch_width)
+        patch_dim = channels * patch_height * patch_width
+
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
+            nn.Linear(patch_dim, dim),
+        )
+
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
+
+        self.to_latent = nn.Identity()
+        self.linear_head = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, num_classes)
+        )
+
+    def forward(self, img):
+        *_, h, w, dtype = *img.shape, img.dtype
+
+        x = self.to_patch_embedding(img)
+        pe = posemb_sincos_2d(x)
+        x = rearrange(x, 'b ... d -> b (...) d') + pe
+
+        x = self.transformer(x)
+        x = x.mean(dim = 1)
+
+        x = self.to_latent(x)
+        return self.linear_head(x)
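
For readers skimming the new module, the following sketch (again, not part of the commit) traces tensor shapes through `SimpleViT` with the same hyperparameters as the README example, making the patch grid produced by `to_patch_embedding` and the mean pooling that replaces the CLS token explicit.

```python
# Sketch only: trace shapes through the SimpleViT defined above.
# Assumes the vit_pytorch package at this commit, which re-exports SimpleViT from __init__.py.
import torch
from vit_pytorch import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(2, 3, 256, 256)

tokens = v.to_patch_embedding(img)  # (2, 8, 8, 1024): one token per 32x32 patch, no CLS token
preds = v(img)                      # (2, 1000): tokens are mean-pooled, then LayerNorm + Linear
print(tokens.shape, preds.shape)
```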
