|
| 1 | +""" |
| 2 | +ViT + Hyper-Connections + Register Tokens |
| 3 | +https://arxiv.org/abs/2409.19606 |
| 4 | +""" |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch import nn, tensor |
| 8 | +from torch.nn import Module, ModuleList |
| 9 | + |
| 10 | +from einops import rearrange, repeat, reduce, einsum, pack, unpack |
| 11 | +from einops.layers.torch import Rearrange |
| 12 | + |
| 13 | +# b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension |
| 14 | + |
| 15 | +# helpers |
| 16 | + |
| 17 | +def pair(t): |
| 18 | + return t if isinstance(t, tuple) else (t, t) |
| 19 | + |
| 20 | +def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): |
| 21 | + y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") |
| 22 | + assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" |
| 23 | + omega = torch.arange(dim // 4) / (dim // 4 - 1) |
| 24 | + omega = 1.0 / (temperature ** omega) |
| 25 | + |
| 26 | + y = y.flatten()[:, None] * omega[None, :] |
| 27 | + x = x.flatten()[:, None] * omega[None, :] |
| 28 | + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) |
| 29 | + return pe.type(dtype) |
| 30 | + |
| 31 | +# hyper connections |
| 32 | + |
| 33 | +class HyperConnection(Module): |
| 34 | + def __init__( |
| 35 | + self, |
| 36 | + dim, |
| 37 | + num_residual_streams, |
| 38 | + layer_index |
| 39 | + ): |
| 40 | + """ Appendix J - Algorithm 2, Dynamic only """ |
| 41 | + super().__init__() |
| 42 | + |
| 43 | + self.norm = nn.LayerNorm(dim, bias = False) |
| 44 | + |
| 45 | + self.num_residual_streams = num_residual_streams |
| 46 | + self.layer_index = layer_index |
| 47 | + |
| 48 | + self.static_beta = nn.Parameter(torch.ones(num_residual_streams)) |
| 49 | + |
| 50 | + init_alpha0 = torch.zeros((num_residual_streams, 1)) |
| 51 | + init_alpha0[layer_index % num_residual_streams, 0] = 1. |
| 52 | + |
| 53 | + self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1)) |
| 54 | + |
| 55 | + self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1)) |
| 56 | + self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2)) |
| 57 | + self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim)) |
| 58 | + self.dynamic_beta_scale = nn.Parameter(tensor(1e-2)) |
| 59 | + |
| 60 | + def width_connection(self, residuals): |
| 61 | + normed = self.norm(residuals) |
| 62 | + |
| 63 | + wc_weight = (normed @ self.dynamic_alpha_fn).tanh() |
| 64 | + dynamic_alpha = wc_weight * self.dynamic_alpha_scale |
| 65 | + alpha = dynamic_alpha + self.static_alpha |
| 66 | + |
| 67 | + dc_weight = (normed @ self.dynamic_beta_fn).tanh() |
| 68 | + dynamic_beta = dc_weight * self.dynamic_beta_scale |
| 69 | + beta = dynamic_beta + self.static_beta |
| 70 | + |
| 71 | + # width connection |
| 72 | + mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d') |
| 73 | + |
| 74 | + branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :] |
| 75 | + |
| 76 | + return branch_input, residuals, beta |
| 77 | + |
| 78 | + def depth_connection( |
| 79 | + self, |
| 80 | + residuals, |
| 81 | + branch_output, |
| 82 | + beta |
| 83 | + ): |
| 84 | + return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals |
| 85 | + |
| 86 | +# classes |
| 87 | + |
| 88 | +class FeedForward(Module): |
| 89 | + def __init__(self, dim, hidden_dim): |
| 90 | + super().__init__() |
| 91 | + self.net = nn.Sequential( |
| 92 | + nn.LayerNorm(dim), |
| 93 | + nn.Linear(dim, hidden_dim), |
| 94 | + nn.GELU(), |
| 95 | + nn.Linear(hidden_dim, dim), |
| 96 | + ) |
| 97 | + def forward(self, x): |
| 98 | + return self.net(x) |
| 99 | + |
| 100 | +class Attention(Module): |
| 101 | + def __init__(self, dim, heads = 8, dim_head = 64): |
| 102 | + super().__init__() |
| 103 | + inner_dim = dim_head * heads |
| 104 | + self.heads = heads |
| 105 | + self.scale = dim_head ** -0.5 |
| 106 | + self.norm = nn.LayerNorm(dim) |
| 107 | + |
| 108 | + self.attend = nn.Softmax(dim = -1) |
| 109 | + |
| 110 | + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) |
| 111 | + self.to_out = nn.Linear(inner_dim, dim, bias = False) |
| 112 | + |
| 113 | + def forward(self, x): |
| 114 | + x = self.norm(x) |
| 115 | + |
| 116 | + qkv = self.to_qkv(x).chunk(3, dim = -1) |
| 117 | + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) |
| 118 | + |
| 119 | + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale |
| 120 | + |
| 121 | + attn = self.attend(dots) |
| 122 | + |
| 123 | + out = torch.matmul(attn, v) |
| 124 | + out = rearrange(out, 'b h n d -> b n (h d)') |
| 125 | + return self.to_out(out) |
| 126 | + |
| 127 | +class Transformer(Module): |
| 128 | + def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams): |
| 129 | + super().__init__() |
| 130 | + |
| 131 | + self.num_residual_streams = num_residual_streams |
| 132 | + |
| 133 | + self.norm = nn.LayerNorm(dim) |
| 134 | + self.layers = ModuleList([]) |
| 135 | + |
| 136 | + for layer_index in range(depth): |
| 137 | + self.layers.append(nn.ModuleList([ |
| 138 | + HyperConnection(dim, num_residual_streams, layer_index), |
| 139 | + Attention(dim, heads = heads, dim_head = dim_head), |
| 140 | + HyperConnection(dim, num_residual_streams, layer_index), |
| 141 | + FeedForward(dim, mlp_dim) |
| 142 | + ])) |
| 143 | + |
| 144 | + def forward(self, x): |
| 145 | + |
| 146 | + x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams) |
| 147 | + |
| 148 | + for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers: |
| 149 | + |
| 150 | + x, attn_res, beta = attn_hyper_conn.width_connection(x) |
| 151 | + |
| 152 | + x = attn(x) |
| 153 | + |
| 154 | + x = attn_hyper_conn.depth_connection(attn_res, x, beta) |
| 155 | + |
| 156 | + x, ff_res, beta = ff_hyper_conn.width_connection(x) |
| 157 | + |
| 158 | + x = ff(x) |
| 159 | + |
| 160 | + x = ff_hyper_conn.depth_connection(ff_res, x, beta) |
| 161 | + |
| 162 | + x = reduce(x, 'b n e d -> b n d', 'sum') |
| 163 | + |
| 164 | + return self.norm(x) |
| 165 | + |
| 166 | +class SimpleViT(nn.Module): |
| 167 | + def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64): |
| 168 | + super().__init__() |
| 169 | + image_height, image_width = pair(image_size) |
| 170 | + patch_height, patch_width = pair(patch_size) |
| 171 | + |
| 172 | + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' |
| 173 | + |
| 174 | + patch_dim = channels * patch_height * patch_width |
| 175 | + |
| 176 | + self.to_patch_embedding = nn.Sequential( |
| 177 | + Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), |
| 178 | + nn.LayerNorm(patch_dim), |
| 179 | + nn.Linear(patch_dim, dim), |
| 180 | + nn.LayerNorm(dim), |
| 181 | + ) |
| 182 | + |
| 183 | + self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim)) |
| 184 | + |
| 185 | + self.pos_embedding = posemb_sincos_2d( |
| 186 | + h = image_height // patch_height, |
| 187 | + w = image_width // patch_width, |
| 188 | + dim = dim, |
| 189 | + ) |
| 190 | + |
| 191 | + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams) |
| 192 | + |
| 193 | + self.pool = "mean" |
| 194 | + self.to_latent = nn.Identity() |
| 195 | + |
| 196 | + self.linear_head = nn.Linear(dim, num_classes) |
| 197 | + |
| 198 | + def forward(self, img): |
| 199 | + batch, device = img.shape[0], img.device |
| 200 | + |
| 201 | + x = self.to_patch_embedding(img) |
| 202 | + x += self.pos_embedding.to(x) |
| 203 | + |
| 204 | + r = repeat(self.register_tokens, 'n d -> b n d', b = batch) |
| 205 | + |
| 206 | + x, ps = pack([x, r], 'b * d') |
| 207 | + |
| 208 | + x = self.transformer(x) |
| 209 | + |
| 210 | + x, _ = unpack(x, ps, 'b * d') |
| 211 | + |
| 212 | + x = x.mean(dim = 1) |
| 213 | + |
| 214 | + x = self.to_latent(x) |
| 215 | + return self.linear_head(x) |
| 216 | + |
| 217 | +# main |
| 218 | + |
| 219 | +if __name__ == '__main__': |
| 220 | + vit = SimpleViT( |
| 221 | + num_classes = 1000, |
| 222 | + image_size = 256, |
| 223 | + patch_size = 8, |
| 224 | + dim = 1024, |
| 225 | + depth = 12, |
| 226 | + heads = 8, |
| 227 | + mlp_dim = 2048, |
| 228 | + num_residual_streams = 8 |
| 229 | + ) |
| 230 | + |
| 231 | + images = torch.randn(3, 3, 256, 256) |
| 232 | + |
| 233 | + logits = vit(images) |
0 commit comments