Skip to content

Commit 13b313b

Browse files
committed
add a simple vit flavor for a new bytedance paper that proposes to break out of the traditional one residual stream architecture - "hyper-connections"
1 parent 56373c0 commit 13b313b

File tree

3 files changed

+245
-1
lines changed

3 files changed

+245
-1
lines changed

Diff for: README.md

+11
Original file line numberDiff line numberDiff line change
@@ -2161,4 +2161,15 @@ Coming from computer vision and new to transformers? Here are some resources tha
21612161
}
21622162
```
21632163

2164+
```bibtex
2165+
@article{Zhu2024HyperConnections,
2166+
title = {Hyper-Connections},
2167+
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
2168+
journal = {ArXiv},
2169+
year = {2024},
2170+
volume = {abs/2409.19606},
2171+
url = {https://api.semanticscholar.org/CorpusID:272987528}
2172+
}
2173+
```
2174+
21642175
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
setup(
77
name = 'vit-pytorch',
88
packages = find_packages(exclude=['examples']),
9-
version = '1.8.9',
9+
version = '1.9.0',
1010
license='MIT',
1111
description = 'Vision Transformer (ViT) - Pytorch',
1212
long_description = long_description,

Diff for: vit_pytorch/simple_vit_with_hyper_connections.py

+233
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
"""
2+
ViT + Hyper-Connections + Register Tokens
3+
https://arxiv.org/abs/2409.19606
4+
"""
5+
6+
import torch
7+
from torch import nn, tensor
8+
from torch.nn import Module, ModuleList
9+
10+
from einops import rearrange, repeat, reduce, einsum, pack, unpack
11+
from einops.layers.torch import Rearrange
12+
13+
# b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension
14+
15+
# helpers
16+
17+
def pair(t):
18+
return t if isinstance(t, tuple) else (t, t)
19+
20+
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
21+
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
22+
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
23+
omega = torch.arange(dim // 4) / (dim // 4 - 1)
24+
omega = 1.0 / (temperature ** omega)
25+
26+
y = y.flatten()[:, None] * omega[None, :]
27+
x = x.flatten()[:, None] * omega[None, :]
28+
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
29+
return pe.type(dtype)
30+
31+
# hyper connections
32+
33+
class HyperConnection(Module):
34+
def __init__(
35+
self,
36+
dim,
37+
num_residual_streams,
38+
layer_index
39+
):
40+
""" Appendix J - Algorithm 2, Dynamic only """
41+
super().__init__()
42+
43+
self.norm = nn.LayerNorm(dim, bias = False)
44+
45+
self.num_residual_streams = num_residual_streams
46+
self.layer_index = layer_index
47+
48+
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
49+
50+
init_alpha0 = torch.zeros((num_residual_streams, 1))
51+
init_alpha0[layer_index % num_residual_streams, 0] = 1.
52+
53+
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
54+
55+
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
56+
self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2))
57+
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
58+
self.dynamic_beta_scale = nn.Parameter(tensor(1e-2))
59+
60+
def width_connection(self, residuals):
61+
normed = self.norm(residuals)
62+
63+
wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
64+
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
65+
alpha = dynamic_alpha + self.static_alpha
66+
67+
dc_weight = (normed @ self.dynamic_beta_fn).tanh()
68+
dynamic_beta = dc_weight * self.dynamic_beta_scale
69+
beta = dynamic_beta + self.static_beta
70+
71+
# width connection
72+
mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d')
73+
74+
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
75+
76+
return branch_input, residuals, beta
77+
78+
def depth_connection(
79+
self,
80+
residuals,
81+
branch_output,
82+
beta
83+
):
84+
return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals
85+
86+
# classes
87+
88+
class FeedForward(Module):
89+
def __init__(self, dim, hidden_dim):
90+
super().__init__()
91+
self.net = nn.Sequential(
92+
nn.LayerNorm(dim),
93+
nn.Linear(dim, hidden_dim),
94+
nn.GELU(),
95+
nn.Linear(hidden_dim, dim),
96+
)
97+
def forward(self, x):
98+
return self.net(x)
99+
100+
class Attention(Module):
101+
def __init__(self, dim, heads = 8, dim_head = 64):
102+
super().__init__()
103+
inner_dim = dim_head * heads
104+
self.heads = heads
105+
self.scale = dim_head ** -0.5
106+
self.norm = nn.LayerNorm(dim)
107+
108+
self.attend = nn.Softmax(dim = -1)
109+
110+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
111+
self.to_out = nn.Linear(inner_dim, dim, bias = False)
112+
113+
def forward(self, x):
114+
x = self.norm(x)
115+
116+
qkv = self.to_qkv(x).chunk(3, dim = -1)
117+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
118+
119+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
120+
121+
attn = self.attend(dots)
122+
123+
out = torch.matmul(attn, v)
124+
out = rearrange(out, 'b h n d -> b n (h d)')
125+
return self.to_out(out)
126+
127+
class Transformer(Module):
128+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams):
129+
super().__init__()
130+
131+
self.num_residual_streams = num_residual_streams
132+
133+
self.norm = nn.LayerNorm(dim)
134+
self.layers = ModuleList([])
135+
136+
for layer_index in range(depth):
137+
self.layers.append(nn.ModuleList([
138+
HyperConnection(dim, num_residual_streams, layer_index),
139+
Attention(dim, heads = heads, dim_head = dim_head),
140+
HyperConnection(dim, num_residual_streams, layer_index),
141+
FeedForward(dim, mlp_dim)
142+
]))
143+
144+
def forward(self, x):
145+
146+
x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams)
147+
148+
for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers:
149+
150+
x, attn_res, beta = attn_hyper_conn.width_connection(x)
151+
152+
x = attn(x)
153+
154+
x = attn_hyper_conn.depth_connection(attn_res, x, beta)
155+
156+
x, ff_res, beta = ff_hyper_conn.width_connection(x)
157+
158+
x = ff(x)
159+
160+
x = ff_hyper_conn.depth_connection(ff_res, x, beta)
161+
162+
x = reduce(x, 'b n e d -> b n d', 'sum')
163+
164+
return self.norm(x)
165+
166+
class SimpleViT(nn.Module):
167+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64):
168+
super().__init__()
169+
image_height, image_width = pair(image_size)
170+
patch_height, patch_width = pair(patch_size)
171+
172+
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
173+
174+
patch_dim = channels * patch_height * patch_width
175+
176+
self.to_patch_embedding = nn.Sequential(
177+
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
178+
nn.LayerNorm(patch_dim),
179+
nn.Linear(patch_dim, dim),
180+
nn.LayerNorm(dim),
181+
)
182+
183+
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
184+
185+
self.pos_embedding = posemb_sincos_2d(
186+
h = image_height // patch_height,
187+
w = image_width // patch_width,
188+
dim = dim,
189+
)
190+
191+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)
192+
193+
self.pool = "mean"
194+
self.to_latent = nn.Identity()
195+
196+
self.linear_head = nn.Linear(dim, num_classes)
197+
198+
def forward(self, img):
199+
batch, device = img.shape[0], img.device
200+
201+
x = self.to_patch_embedding(img)
202+
x += self.pos_embedding.to(x)
203+
204+
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
205+
206+
x, ps = pack([x, r], 'b * d')
207+
208+
x = self.transformer(x)
209+
210+
x, _ = unpack(x, ps, 'b * d')
211+
212+
x = x.mean(dim = 1)
213+
214+
x = self.to_latent(x)
215+
return self.linear_head(x)
216+
217+
# main
218+
219+
if __name__ == '__main__':
220+
vit = SimpleViT(
221+
num_classes = 1000,
222+
image_size = 256,
223+
patch_size = 8,
224+
dim = 1024,
225+
depth = 12,
226+
heads = 8,
227+
mlp_dim = 2048,
228+
num_residual_streams = 8
229+
)
230+
231+
images = torch.randn(3, 3, 256, 256)
232+
233+
logits = vit(images)

0 commit comments

Comments
 (0)