Skip to content

Commit dfcfa20

Browse files
committed
add proposed parallel vit from facebook ai for exploration purposes
1 parent c2b2db2 commit dfcfa20

File tree

4 files changed

+179
-1
lines changed

4 files changed

+179
-1
lines changed

README.md

+41
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
- [Adaptive Token Sampling](#adaptive-token-sampling)
2828
- [Patch Merger](#patch-merger)
2929
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
30+
- [Parallel ViT](#parallel-vit)
3031
- [Dino](#dino)
3132
- [Accessing Attention](#accessing-attention)
3233
- [Research Ideas](#research-ideas)
@@ -240,6 +241,7 @@ preds = v(img) # (1, 1000)
240241
```
241242

242243
## CCT
244+
243245
<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
244246

245247
<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers
@@ -866,6 +868,37 @@ img = torch.randn(4, 3, 256, 256)
866868
tokens = spt(img) # (4, 256, 1024)
867869
```
868870

871+
## Parallel ViT
872+
873+
<img src="./images/parallel-vit.png" width="350px"></img>
874+
875+
This <a href="https://arxiv.org/abs/2203.09795">paper</a> propose parallelizing multiple attention and feedforward blocks per layer (2 blocks), claiming that it is easier to train without loss of performance.
876+
877+
You can try this variant as follows
878+
879+
```python
880+
import torch
881+
from vit_pytorch.parallel_vit import ViT
882+
883+
v = ViT(
884+
image_size = 256,
885+
patch_size = 16,
886+
num_classes = 1000,
887+
dim = 1024,
888+
depth = 12,
889+
heads = 8,
890+
mlp_dim = 2048,
891+
num_parallel_branches = 2, # in paper, they claimed 2 was optimal
892+
dropout = 0.1,
893+
emb_dropout = 0.1
894+
)
895+
896+
img = torch.randn(4, 3, 256, 256)
897+
898+
preds = v(img) # (4, 1000)
899+
```
900+
901+
869902
## Dino
870903

871904
<img src="./images/dino.png" width="350px"></img>
@@ -1396,6 +1429,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
13961429
}
13971430
```
13981431

1432+
```bibtex
1433+
@inproceedings{Touvron2022ThreeTE,
1434+
title = {Three things everyone should know about Vision Transformers},
1435+
author = {Hugo Touvron and Matthieu Cord and Alaaeldin El-Nouby and Jakob Verbeek and Herv'e J'egou},
1436+
year = {2022}
1437+
}
1438+
```
1439+
13991440
```bibtex
14001441
@misc{vaswani2017attention,
14011442
title = {Attention Is All You Need},

images/parallel-vit.png

14.3 KB
Loading

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '0.28.2',
6+
version = '0.29.0',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
author = 'Phil Wang',

vit_pytorch/parallel_vit.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import torch
2+
from torch import nn
3+
4+
from einops import rearrange, repeat
5+
from einops.layers.torch import Rearrange
6+
7+
# helpers
8+
9+
def pair(t):
10+
return t if isinstance(t, tuple) else (t, t)
11+
12+
# classes
13+
14+
class Parallel(nn.Module):
15+
def __init__(self, *fns):
16+
super().__init__()
17+
self.fns = nn.ModuleList(fns)
18+
19+
def forward(self, x):
20+
return sum([fn(x) for fn in self.fns])
21+
22+
class PreNorm(nn.Module):
23+
def __init__(self, dim, fn):
24+
super().__init__()
25+
self.norm = nn.LayerNorm(dim)
26+
self.fn = fn
27+
def forward(self, x, **kwargs):
28+
return self.fn(self.norm(x), **kwargs)
29+
30+
class FeedForward(nn.Module):
31+
def __init__(self, dim, hidden_dim, dropout = 0.):
32+
super().__init__()
33+
self.net = nn.Sequential(
34+
nn.Linear(dim, hidden_dim),
35+
nn.GELU(),
36+
nn.Dropout(dropout),
37+
nn.Linear(hidden_dim, dim),
38+
nn.Dropout(dropout)
39+
)
40+
def forward(self, x):
41+
return self.net(x)
42+
43+
class Attention(nn.Module):
44+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
45+
super().__init__()
46+
inner_dim = dim_head * heads
47+
project_out = not (heads == 1 and dim_head == dim)
48+
49+
self.heads = heads
50+
self.scale = dim_head ** -0.5
51+
52+
self.attend = nn.Softmax(dim = -1)
53+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
54+
55+
self.to_out = nn.Sequential(
56+
nn.Linear(inner_dim, dim),
57+
nn.Dropout(dropout)
58+
) if project_out else nn.Identity()
59+
60+
def forward(self, x):
61+
qkv = self.to_qkv(x).chunk(3, dim = -1)
62+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
63+
64+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
65+
66+
attn = self.attend(dots)
67+
68+
out = torch.matmul(attn, v)
69+
out = rearrange(out, 'b h n d -> b n (h d)')
70+
return self.to_out(out)
71+
72+
class Transformer(nn.Module):
73+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.):
74+
super().__init__()
75+
self.layers = nn.ModuleList([])
76+
77+
attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
78+
ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
79+
80+
for _ in range(depth):
81+
self.layers.append(nn.ModuleList([
82+
Parallel(*[attn_block() for _ in range(num_parallel_branches)]),
83+
Parallel(*[ff_block() for _ in range(num_parallel_branches)]),
84+
]))
85+
86+
def forward(self, x):
87+
for attns, ffs in self.layers:
88+
x = attns(x) + x
89+
x = ffs(x) + x
90+
return x
91+
92+
class ViT(nn.Module):
93+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
94+
super().__init__()
95+
image_height, image_width = pair(image_size)
96+
patch_height, patch_width = pair(patch_size)
97+
98+
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
99+
100+
num_patches = (image_height // patch_height) * (image_width // patch_width)
101+
patch_dim = channels * patch_height * patch_width
102+
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
103+
104+
self.to_patch_embedding = nn.Sequential(
105+
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
106+
nn.Linear(patch_dim, dim),
107+
)
108+
109+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
110+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
111+
self.dropout = nn.Dropout(emb_dropout)
112+
113+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout)
114+
115+
self.pool = pool
116+
self.to_latent = nn.Identity()
117+
118+
self.mlp_head = nn.Sequential(
119+
nn.LayerNorm(dim),
120+
nn.Linear(dim, num_classes)
121+
)
122+
123+
def forward(self, img):
124+
x = self.to_patch_embedding(img)
125+
b, n, _ = x.shape
126+
127+
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
128+
x = torch.cat((cls_tokens, x), dim=1)
129+
x += self.pos_embedding[:, :(n + 1)]
130+
x = self.dropout(x)
131+
132+
x = self.transformer(x)
133+
134+
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
135+
136+
x = self.to_latent(x)
137+
return self.mlp_head(x)

0 commit comments

Comments
 (0)